Compare commits
240 Commits
**.gitignore** (29 changed lines, vendored)

```diff
@@ -55,22 +55,6 @@ coverage.xml
 *.mo
 *.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
 # PyBuilder
 target/
 
@@ -129,4 +113,15 @@ dmypy.json
 .pyre/
 
 *.wav
 run_*.sh
+
+# Downloaded models
+*.pt
+
+# Debug & testing
+test_*.py
+launch.json
+.DS_Store
+test/*
+nllb-200-distilled-600M-ctranslate2/*
+*.mp3
```
**Contributing guide**

```diff
@@ -15,7 +15,7 @@ Thank you for considering contributing ! We appreciate your time and effort to h
 
 ## Opening Issues
 
-If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue:
+If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:
 
 - **Bug Reports:**
   - Clearly describe the error. **Please indicate the parameters you use, especially the model(s)**
@@ -43,4 +43,4 @@ We welcome and appreciate contributions! To ensure a smooth review process, plea
 
 ## Thank You
 
-Your contributions make diart better for everyone. Thank you for your time and dedication!
+Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!
```
**DEV_NOTES.md** (new file, 91 lines)

# 1. SimulStreaming: Decouple the encoder for faster inference

Experiments on SimulStreaming encoder time (`whisperlivekit/simul_whisper/simul_whisper.py`, l. 397), on macOS Apple Silicon M4:

| Encoder | base.en | small |
|---------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |

Memory is saved by loading only the encoder from the optimized framework. For tiny.en, MLX Whisper weight sizes:

- Decoder weights: 59110771 bytes
- Encoder weights: 15268874 bytes
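For reference, a minimal sketch of how one row of the table above could be timed with the reference `openai-whisper` package (the Faster-Whisper and MLX rows need their own encoder wiring, which is not shown here; the 30 s of silence used as input is just a stand-in):

```python
import time
import numpy as np
import torch
import whisper

model = whisper.load_model("base.en")
audio = np.zeros(16000 * 30, dtype=np.float32)   # 30 s of silence as a stand-in input
mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio)).to(model.device)

with torch.no_grad():
    model.encoder(mel.unsqueeze(0))               # warm-up pass
    start = time.perf_counter()
    features = model.encoder(mel.unsqueeze(0))    # the part SimulStreaming decouples
print(f"Encoder forward: {time.perf_counter() - start:.2f}s, features shape {tuple(features.shape)}")
```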
# 2. Translation: Faster model for each system

## Benchmark Results

Testing on MacBook M3 with the NLLB-200-distilled-600M model.

### Standard Transformers vs CTranslate2

| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|-----------|-------------------------|----------------------------|---------|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |

**Results:**

- Total Standard time: 4.1068s
- Total CTranslate2 time: 8.5476s
- CTranslate2 is slower on this system, so use Transformers; ideally we would have an MLX implementation.
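As a rough illustration of the Transformers side of this benchmark, a minimal timing sketch (not the repo's benchmark script; the model ID and language codes are assumptions based on the public NLLB release, and the sentences are the first two from the table):

```python
import time
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = "facebook/nllb-200-distilled-600M"   # assumed public checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

sentences = [
    "UN Chief says there is no military solution in Syria",
    "The rapid advancement of AI technology is transforming various industries",
]

total = 0.0
for text in sentences:
    inputs = tokenizer(text, return_tensors="pt")
    start = time.perf_counter()
    generated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("fra_Latn"),  # target language code
        max_new_tokens=64,
    )
    total += time.perf_counter() - start
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

print(f"Total Transformers time: {total:.4f}s")
```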
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm

Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.

## Problem Statement

- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers

### Initial Setup

For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.

Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time.

### Algorithm

```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```

- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b

3. **Attribution Logic**

```
AS_0 ← A
AS_1 ← B

IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B

ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A

ELSE:
    AS_2 ← B

to finish
```
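A minimal Python sketch of the attribution step above (not the code merged in this PR, and the notes themselves are marked "to finish"). `DIST` is an assumption here, taken as the L2 distance between per-frame score vectors:

```python
import numpy as np

def dist(a: np.ndarray, b: np.ndarray) -> float:
    """Assumed distance between two detections: L2 between their score vectors."""
    return float(np.linalg.norm(a - b))

def attribute_first_three(preds_np: np.ndarray) -> list:
    """preds_np: (num_frames, 4) diarization scores; returns labels for frames 0-2."""
    ds = [preds_np[i] for i in range(3)]       # detections DS_a_0 .. DS_a_2
    labels = ["A", "B", None]                  # AS_0 <- A, AS_1 <- B by default
    d01, d02, d12 = dist(ds[0], ds[1]), dist(ds[0], ds[2]), dist(ds[1], ds[2])
    if d01 < d02 and d01 < d12:
        labels[1], labels[2] = "A", "B"        # likely DS_a_0 == DS_a_1 (same speaker)
    elif d02 < d01 and d02 < d12:
        labels[2] = "A"
    else:
        labels[2] = "B"
    return labels

# Example: three frames where frames 0 and 1 clearly share the dominant slot
preds = np.array([[0.9, 0.1, 0.0, 0.0],
                  [0.8, 0.2, 0.0, 0.0],
                  [0.1, 0.9, 0.0, 0.0]])
print(attribute_first_three(preds))  # ['A', 'A', 'B']
```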
**Dockerfile** (49 changed lines)

```diff
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04
+FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
 
 ENV DEBIAN_FRONTEND=noninteractive
 ENV PYTHONUNBUFFERED=1
@@ -9,46 +9,50 @@ ARG EXTRAS
 ARG HF_PRECACHE_DIR
 ARG HF_TKN_FILE
 
-# Install system dependencies
-#RUN apt-get update && \
-#    apt-get install -y ffmpeg git && \
-#    apt-get clean && \
-#    rm -rf /var/lib/apt/lists/*
-
-# 2) Install system dependencies + Python + pip
 RUN apt-get update && \
     apt-get install -y --no-install-recommends \
     python3 \
     python3-pip \
+    python3-venv \
     ffmpeg \
-    git && \
+    git \
+    build-essential \
+    python3-dev \
+    ca-certificates && \
    rm -rf /var/lib/apt/lists/*
 
-RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+RUN python3 -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# timeout/retries for large torch wheels
+RUN pip3 install --upgrade pip setuptools wheel && \
+    pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
+    --index-url https://download.pytorch.org/whl/cu129 \
+    torch torchaudio \
+    || (echo "Initial install failed — retrying with extended timeout..." && \
+    pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
+    --index-url https://download.pytorch.org/whl/cu129 \
+    torch torchvision torchaudio)
 
 COPY . .
 
 # Install WhisperLiveKit directly, allowing for optional dependencies
-# Note: For gates modedls, need to add your HF toke. See README.md
-# for more details.
 RUN if [ -n "$EXTRAS" ]; then \
     echo "Installing with extras: [$EXTRAS]"; \
-    pip install --no-cache-dir .[$EXTRAS]; \
+    pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
     else \
     echo "Installing base package only"; \
-    pip install --no-cache-dir .; \
+    pip install --no-cache-dir whisperlivekit; \
     fi
 
-# Enable in-container caching for Hugging Face models by:
-# Note: If running multiple containers, better to map a shared
-# bucket.
-#
+# In-container caching for Hugging Face models by:
 # A) Make the cache directory persistent via an anonymous volume.
 # Note: This only persists for a single, named container. This is
 # only for convenience at de/test stage.
 # For prod, it is better to use a named volume via host mount/k8s.
 VOLUME ["/root/.cache/huggingface/hub"]
 
 
 # or
 # B) Conditionally copy a local pre-cache from the build context to the
 # container's cache via the HF_PRECACHE_DIR build-arg.
@@ -63,8 +67,7 @@ RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
     echo "No local Hugging Face cache specified, skipping copy"; \
     fi
 
-# Conditionally copy a Hugging Face token if provided
-
+# Conditionally copy a Hugging Face token if provided. Useful for Diart backend (pyannote audio models)
 RUN if [ -n "$HF_TKN_FILE" ]; then \
     echo "Copying Hugging Face token from $HF_TKN_FILE"; \
     mkdir -p /root/.cache/huggingface && \
@@ -72,11 +75,9 @@ RUN if [ -n "$HF_TKN_FILE" ]; then \
     else \
     echo "No Hugging Face token file specified, skipping token setup"; \
     fi
 
-# Expose port for the transcription server
 EXPOSE 8000
 
 ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]
 
-# Default args
-CMD ["--model", "medium"]
+CMD ["--model", "tiny.en"]
```
**Dockerfile.cpu** (new file, 61 lines)

```dockerfile
FROM python:3.13-slim

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

WORKDIR /app

ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    build-essential \
    python3-dev && \
    rm -rf /var/lib/apt/lists/*

# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

COPY . .

# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
    echo "Installing with extras: [$EXTRAS]"; \
    pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
    else \
    echo "Installing base package only"; \
    pip install --no-cache-dir whisperlivekit; \
    fi

# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]

# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
    echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
    mkdir -p /root/.cache/huggingface/hub && \
    cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
    else \
    echo "No local Hugging Face cache specified, skipping copy"; \
    fi

# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
    echo "Copying Hugging Face token from $HF_TKN_FILE"; \
    mkdir -p /root/.cache/huggingface && \
    cp $HF_TKN_FILE /root/.cache/huggingface/token; \
    else \
    echo "No Hugging Face token file specified, skipping token setup"; \
    fi

# Expose port for the transcription server
EXPOSE 8000

ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]
```
**LICENSE** (224 changed lines)

The MIT License text ("MIT License / Copyright (c) 2025 Quentin Fuxa. / Permission is hereby granted, free of charge, ...") is replaced by the full Apache License, Version 2.0 (http://www.apache.org/licenses/): the standard Apache-2.0 terms and conditions (Sections 1 through 9, from Definitions to Accepting Warranty or Additional Liability, plus the appendix on how to apply the license), followed by this notice:

    Copyright 2025 Quentin Fuxa

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
The attribution list at the end of the file is also rewritten:

```diff
 ---
 
-Based on:
-- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming. The original work by ÚFAL. License: https://github.com/ufal/whisper_streaming/blob/main/LICENSE
-- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad. The work by Snakers4 (silero-vad). License: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE
-- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart. The work in Diart by juanmc2005. License: https://github.com/juanmc2005/diart/blob/main/LICENSE
+## Based on:
+- **SimulWhisper** by Speech and Audio Technology LAB of Tsinghua University – Apache-2.0 – https://github.com/ufal/SimulStreaming
+- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
+- **NeMo** by NVidia - Apache-2.0 - https://github.com/NVIDIA-NeMo/NeMo
+- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming.
+- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad.
+- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart.
```
**README.md** (404 changed lines)

````diff
@@ -4,176 +4,118 @@
 <img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
 </p>
 
-<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Diarization</b></p>
+<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>
 
 <p align="center">
 <a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
-<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads"></a>
-<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
-<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT-dark_green"></a>
+<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
+<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
+<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
 </p>
 
-## 🚀 Overview
-
-This project is based on [WhisperStreaming](https://github.com/ufal/whisper_streaming) and [SimulStreaming](https://github.com/ufal/SimulStreaming), allowing you to transcribe audio directly from your browser. WhisperLiveKit provides a complete backend solution for real-time speech transcription with a functional, simple and customizable frontend. Everything runs locally on your machine ✨
+Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.
 
-### 🔄 Architecture
+#### Powered by Leading Research:
 
-WhisperLiveKit consists of three main components:
-
-- **Frontend**: A basic html + JS interface that captures microphone audio and streams it to the backend via WebSockets. You can use and adapt the provided template at [whisperlivekit/web/live_transcription.html](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html).
-- **Backend (Web Server)**: A FastAPI-based WebSocket server that receives streamed audio data, processes it in real time, and returns transcriptions to the frontend. This is where the WebSocket logic and routing live.
-- **Core Backend (Library Logic)**: A server-agnostic core that handles audio processing, ASR, and diarization. It exposes reusable components that take in audio bytes and return transcriptions.
+- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
+- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simulatenous translation from & to 200 languages.
+- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
+- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
+- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
+- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection
 
-### ✨ Key Features
-
-- **🎙️ Real-time Transcription** - Locally (or on-prem) convert speech to text instantly as you speak
-- **👥 Speaker Diarization** - Identify different speakers in real-time using [Diart](https://github.com/juanmc2005/diart)
-- **🌐 Multi-User Support** - Handle multiple users simultaneously with a single backend/server
-- **🔇 Automatic Silence Chunking** – Automatically chunks when no audio is detected to limit buffer size
-- **✅ Confidence Validation** – Immediately validate high-confidence tokens for faster inference (WhisperStreaming only)
-- **👁️ Buffering Preview** – Displays unvalidated transcription segments (not compatible with SimulStreaming yet)
-- **✒️ Punctuation-Based Speaker Splitting [BETA]** - Align speaker changes with natural sentence boundaries for more readable transcripts
-- **⚡ SimulStreaming Backend** - Ultra-low latency transcription using state-of-the-art AlignAtt policy. The code is not directly included in the repo : To use, please copy [simul_whisper](https://github.com/ufal/SimulStreaming/tree/main/simul_whisper) content into `whisperlivekit/simul_whisper` . ⚠️ You must comply with the [Polyform license](https://github.com/ufal/SimulStreaming/blob/main/LICENCE.txt)
+> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.
 
-## 📖 Quick Start
+### Architecture
 
-```bash
-# Install the package
-pip install whisperlivekit
-
-# Start the transcription server
-whisperlivekit-server --model tiny.en
-
-# Open your browser at http://localhost:8000 to see the interface.
-# Use -ssl-certfile public.crt --ssl-keyfile private.key parameters to use SSL
-```
-
-That's it! Start speaking and watch your words appear on screen.
-
-## 🛠️ Installation Options
-
-### Install from PyPI (Recommended)
+<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
+
+*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*
+
+### Installation & Quick Start
 
 ```bash
 pip install whisperlivekit
 ```
+> You can also clone the repo and `pip install -e .` for the latest version.
 
-### Install from Source
-
-```bash
-git clone https://github.com/QuentinFuxa/WhisperLiveKit
-cd WhisperLiveKit
-pip install -e .
-```
-
-### System Dependencies
-
-FFmpeg is required:
-
-```bash
-# Ubuntu/Debian
-sudo apt install ffmpeg
-
-# macOS
-brew install ffmpeg
-
-# Windows
-# Download from https://ffmpeg.org/download.html and add to PATH
-```
-
-### Optional Dependencies
-
-```bash
-# Voice Activity Controller (prevents hallucinations)
-pip install torch
-
-# Sentence-based buffer trimming
-pip install mosestokenizer wtpsplit
-pip install tokenize_uk # If you work with Ukrainian text
-
-# Speaker diarization
-pip install diart
-
-# Alternative Whisper backends (default is faster-whisper)
-pip install whisperlivekit[whisper] # Original Whisper
-pip install whisperlivekit[whisper-timestamped] # Improved timestamps
-pip install whisperlivekit[mlx-whisper] # Apple Silicon optimization
-pip install whisperlivekit[openai] # OpenAI API
-pip install whisperlivekit[simulstreaming]
-```
-
-### 🎹 Pyannote Models Setup
-
-For diarization, you need access to pyannote.audio models:
-
-1. [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
-2. [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
-3. [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
-4. Login with HuggingFace:
-```bash
-pip install huggingface_hub
-huggingface-cli login
-```
-
-## 💻 Usage Examples
-
-### Command-line Interface
-
-Start the transcription server with various options:
-
-```bash
-# Basic server with English model
-whisperlivekit-server --model tiny.en
-
-# Advanced configuration with diarization
-whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language auto
-
-# SimulStreaming backend for ultra-low latency
-whisperlivekit-server --backend simulstreaming --model large-v3 --frame-threshold 20
-```
+#### Quick Start
+1. **Start the transcription server:**
+```bash
+whisperlivekit-server --model base --language en
+```
+
+2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!
+
+> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
+> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
+
+#### Use it to capture audio from web pages.
+
+Go to `chrome-extension` for instructions.
+
+<p align="center">
+<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
+</p>
+
+#### Optional Dependencies
+
+| Optional | `pip install` |
+|-----------|-------------|
+| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
+| **Apple Silicon optimizations** | `mlx-whisper` |
+| **Translation** | `nllw` |
+| *[Not recommanded]* Speaker diarization with Diart | `diart` |
+| *[Not recommanded]* Original Whisper backend | `whisper` |
+| *[Not recommanded]* Improved timestamps backend | `whisper-timestamped` |
+| OpenAI API backend | `openai` |
+
+See **Parameters & Configuration** below on how to use them.
+
+### Usage Examples
+
+**Command-line Interface**: Start the transcription server with various options:
+
+```bash
+# Large model and translate from french to danish
+whisperlivekit-server --model large-v3 --language fr --target-language da
+
+# Diarization and server listening on */80
+whisperlivekit-server --host 0.0.0.0 --port 80 --model medium --diarization --language fr
+```
 
-### Python API Integration (Backend)
-Check [basic_server.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a complete example.
+**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.
 
 ```python
-from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
+from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 from contextlib import asynccontextmanager
 import asyncio
 
-# Global variable for the transcription engine
 transcription_engine = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global transcription_engine
-    # Example: Initialize with specific parameters directly
-    # You can also load from command-line arguments using parse_args()
-    # args = parse_args()
-    # transcription_engine = TranscriptionEngine(**vars(args))
     transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
     yield
 
 app = FastAPI(lifespan=lifespan)
 
-# Serve the web interface
-@app.get("/")
-async def get():
-    return HTMLResponse(get_web_interface_html())
-
-# Process WebSocket connections
 async def handle_websocket_results(websocket: WebSocket, results_generator):
-    try:
-        async for response in results_generator:
-            await websocket.send_json(response)
-        await websocket.send_json({"type": "ready_to_stop"})
-    except WebSocketDisconnect:
-        print("WebSocket disconnected during results handling.")
+    async for response in results_generator:
+        await websocket.send_json(response)
+    await websocket.send_json({"type": "ready_to_stop"})
 
 @app.websocket("/asr")
 async def websocket_endpoint(websocket: WebSocket):
````
````diff
@@ -182,65 +124,54 @@ async def websocket_endpoint(websocket: WebSocket):
     # Create a new AudioProcessor for each connection, passing the shared engine
     audio_processor = AudioProcessor(transcription_engine=transcription_engine)
     results_generator = await audio_processor.create_tasks()
-    send_results_to_client = handle_websocket_results(websocket, results_generator)
-    results_task = asyncio.create_task(send_results_to_client)
+    results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
     await websocket.accept()
-    try:
-        while True:
-            message = await websocket.receive_bytes()
-            await audio_processor.process_audio(message)
-    except WebSocketDisconnect:
-        print(f"Client disconnected: {websocket.client}")
-    except Exception as e:
-        await websocket.close(code=1011, reason=f"Server error: {e}")
-    finally:
-        results_task.cancel()
-        try:
-            await results_task
-        except asyncio.CancelledError:
-            logger.info("Results task successfully cancelled.")
+    while True:
+        message = await websocket.receive_bytes()
+        await audio_processor.process_audio(message)
 ```
 
-### Frontend Implementation
-
-The package includes a simple HTML/JavaScript implementation that you can adapt for your project. You can find it in `whisperlivekit/web/live_transcription.html`, or load its content using the `get_web_interface_html()` function from `whisperlivekit`:
-
-```python
-from whisperlivekit import get_web_interface_html
-
-# ... later in your code where you need the HTML string ...
-html_content = get_web_interface_html()
-```
-
-## ⚙️ Configuration Reference
-
-WhisperLiveKit offers extensive configuration options:
+**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
+
+## Parameters & Configuration
 
 | Parameter | Description | Default |
 |-----------|-------------|---------|
+| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
+| `--model-path` | .pt file/directory containing whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
+| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
+| `--target-language` | If sets, translate to using NLLB. Ex: `fr`. [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you should rather use `--task translate`, since Whisper can do it directly. | `None` |
+| `--task` | Set to `translate` to translate *only* to english, using Whisper translation. | `transcribe` |
+| `--diarization` | Enable speaker identification | `False` |
+| `--backend` | Processing backend. You can switch to `faster-whisper` if `simulstreaming` does not work correctly | `simulstreaming` |
+| `--no-vac` | Disable Voice Activity Controller | `False` |
+| `--no-vad` | Disable Voice Activity Detection | `False` |
+| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
 | `--host` | Server host address | `localhost` |
 | `--port` | Server port | `8000` |
-| `--model` | Whisper model size. Caution : '.en' models do not work with Simulstreaming | `tiny` |
-| `--language` | Source language code or `auto` | `en` |
-| `--task` | `transcribe` or `translate` | `transcribe` |
-| `--backend` | Processing backend | `faster-whisper` |
-| `--diarization` | Enable speaker identification | `False` |
-| `--punctuation-split` | Use punctuation to improve speaker boundaries | `True` |
-| `--confidence-validation` | Use confidence scores for faster validation | `False` |
-| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
-| `--vac` | Use Voice Activity Controller | `False` |
-| `--no-vad` | Disable Voice Activity Detection | `False` |
-| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
-| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
 | `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
 | `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
-| `--segmentation-model` | Hugging Face model ID for pyannote.audio segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
-| `--embedding-model` | Hugging Face model ID for pyannote.audio embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
+| `--forwarded-allow-ips` | Ip or Ips allowed to reverse proxy the whisperlivekit-server. Supported types are IP Addresses (e.g. 127.0.0.1), IP Networks (e.g. 10.100.0.0/16), or Literals (e.g. /path/to/socket.sock) | `None` |
+| `--pcm-input` | raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder | `False` |
 
-**SimulStreaming-specific Options:**
-
-| Parameter | Description | Default |
+| Translation options | Description | Default |
 |-----------|-------------|---------|
+| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
+| `--nllb-size` | `600M` or `1.3B` | `600M` |
+
+| Diarization options | Description | Default |
+|-----------|-------------|---------|
+| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
+| `--disable-punctuation-split` | Disable punctuation based splits. See #214 | `False` |
+| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
+| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
+
+| SimulStreaming backend options | Description | Default |
+|-----------|-------------|---------|
+| `--disable-fast-encoder` | Disable Faster Whisper or MLX Whisper backends for the encoder (if installed). Inference can be slower but helpful when GPU memory is limited | `False` |
+| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used | `None` |
 | `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
 | `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
 | `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
````
@@ -251,116 +182,87 @@ WhisperLiveKit offers extensive configuration options:
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to a .pt model file; downloaded if not found | `./base.pt` |
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set it up to the expected number of concurrent users) | `1` |

| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |

## 🔧 How It Works

1. **Audio Capture**: Browser's MediaRecorder API captures audio in webm/opus format
2. **Streaming**: Audio chunks are sent to the server via WebSocket
3. **Processing**: Server decodes audio with FFmpeg and streams it into Whisper for transcription
4. **Real-time Output**:
   - Partial transcriptions appear immediately in light gray (the 'aperçu')
   - Finalized text appears in normal color
   - (When enabled) Different speakers are identified and highlighted
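
For reference, here is a minimal browser-side sketch of steps 1 and 2. It is not the bundled frontend (that lives in `live_transcription.html`/`live_transcription.js`); the `/asr` path matches the Python server example in this repository, while the chunk interval is an illustrative assumption.

```typescript
// Minimal client sketch (illustrative; the real frontend ships with the package).
async function startTranscription(serverUrl: string): Promise<() => void> {
  const socket = new WebSocket(serverUrl); // e.g. "ws://localhost:8000/asr"

  socket.onmessage = (event: MessageEvent) => {
    // Each message is a JSON transcription update sent by the server.
    console.log(JSON.parse(event.data));
  };

  // Step 1: capture microphone audio as webm/opus with MediaRecorder.
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
  const recorder = new MediaRecorder(stream, { mimeType: "audio/webm;codecs=opus" });

  // Step 2: stream each encoded chunk to the server over the WebSocket.
  recorder.ondataavailable = (event: BlobEvent) => {
    if (event.data.size > 0 && socket.readyState === WebSocket.OPEN) {
      socket.send(event.data);
    }
  };

  recorder.start(250); // chunk interval in ms (illustrative value)
  return () => { recorder.stop(); socket.close(); };
}
```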

> For diarization using Diart, you need to accept the user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model, and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, log in to Hugging Face: `huggingface-cli login`

## 🚀 Deployment Guide

To deploy WhisperLiveKit in production:

1. **Server Setup**: Install a production ASGI server and launch it with multiple workers

   ```bash
   # Install production ASGI server
   pip install uvicorn gunicorn

   # Launch with multiple workers
   gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
   ```

2. **Frontend**: Host your customized version of the example HTML/JS in your web application, and make sure the WebSocket connection points to your server's address

3. **Nginx Configuration** (recommended for production):

   ```nginx
   server {
       listen 80;
       server_name your-domain.com;

       location / {
           proxy_pass http://localhost:8000;
           proxy_set_header Upgrade $http_upgrade;
           proxy_set_header Connection "upgrade";
           proxy_set_header Host $host;
       }
   }
   ```

4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in the WebSocket URL
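
Browsers generally block plain `ws://` connections from pages served over HTTPS, so the scheme should follow the page. A small sketch, assuming the WhisperLiveKit server is reachable on the same host and port as the page:

```typescript
// Pick wss:// on HTTPS pages and ws:// otherwise (same-host assumption).
const scheme = window.location.protocol === "https:" ? "wss" : "ws";
const websocketUrl = `${scheme}://${window.location.host}/asr`;
```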

## 🐋 Docker

Deploy the application easily using Docker, with GPU or CPU support. A basic Dockerfile is provided that lets you re-use the Python package installation options; see the usage examples below.

### Prerequisites

- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed

### Quick Start

**With GPU acceleration (recommended):**

```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```

> **Note**: If you're running on a system without NVIDIA GPU support (such as a Mac with Apple Silicon, or any system without CUDA capabilities), remove the `--gpus all` flag. Without GPU acceleration, transcription runs on the CPU only, which may be significantly slower; consider using small models on CPU-only systems.

**CPU only:**

```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```

### Advanced Usage

**Custom configuration:**

```bash
# Example with a custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```

### Memory Requirements

- **Large models**: Ensure your Docker runtime has sufficient memory allocated

#### Customization

- Customize the container options:

  ```bash
  docker build -t whisperlivekit-defaults .
  docker create --gpus all --name whisperlivekit-base -p 8000:8000 whisperlivekit-defaults --model base
  docker start -i whisperlivekit-base
  ```

- `--build-arg` options:
  - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set the necessary container options!
  - `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
  - `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

## 🔮 Use Cases

- **Meeting Transcription**: Capture discussions in real-time
- **Accessibility Tools**: Help hearing-impaired users follow conversations
- **Content Creation**: Transcribe podcasts or videos automatically
- **Customer Service**: Transcribe support calls with speaker identification
## 📄 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

**⚠️ Important**: When using the SimulStreaming backend, you must also comply with the **PolyForm Noncommercial License 1.0.0** that governs SimulStreaming. For commercial use of the SimulStreaming backend, obtain a commercial license from the [SimulStreaming authors](https://github.com/ufal/SimulStreaming#-licence-and-contributions).

## 🤝 Contributing

Contributions are welcome! Here's how to get started:

1. Fork the repository
2. Create a feature branch: `git checkout -b feature/amazing-feature`
3. Commit your changes: `git commit -m 'Add amazing feature'`
4. Push to your branch: `git push origin feature/amazing-feature`
5. Open a Pull Request

## 🙏 Acknowledgments

This project builds upon the foundational work of:

- [Whisper Streaming](https://github.com/ufal/whisper_streaming)
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (BETA backend)
- [Diart](https://github.com/juanmc2005/diart)
- [OpenAI Whisper](https://github.com/openai/whisper)

We extend our gratitude to the original authors for their contributions.

## 🔗 Links

- [GitHub Repository](https://github.com/QuentinFuxa/WhisperLiveKit)
- [PyPI Package](https://pypi.org/project/whisperlivekit/)
- [Issue Tracker](https://github.com/QuentinFuxa/WhisperLiveKit/issues)
258
ReadmeJP.md
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
<h1 align="center">WhisperLiveKit</h1>

<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>

<p align="center"><b>Real-time, fully local speech-to-text with speaker identification</b></p>

<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
</p>

A ready-to-use backend + server with a simple frontend, delivering real-time speech transcription directly to your browser. ✨

#### Powered by leading research:

- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - ultra-low-latency transcription using the AlignAtt policy
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - low-latency transcription using the LocalAgreement policy
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - enterprise-grade voice activity detection

> **Why not simply run a Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts words mid-syllable, and produces poor transcriptions. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.

### Architecture

<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />

*The backend supports multiple concurrent users. Voice activity detection reduces overhead when no speech is detected.*

### Installation & Quick Start

```bash
pip install whisperlivekit
```

> **FFmpeg is required** and must be installed before using WhisperLiveKit.
>
> | OS | How to install |
> |-----------|-------------|
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
> | MacOS | `brew install ffmpeg` |
> | Windows | Download the .exe from https://ffmpeg.org/download.html and add it to PATH |

#### Quick Start

1. **Start the transcription server:**

```bash
whisperlivekit-server --model base --language en
```

2. **Open your browser** at `http://localhost:8000`. Start speaking, and your words appear in real time!

> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the SSL configuration options in the **Parameters** section.

#### Optional dependencies

| Optional | `pip install` |
|-----------|-------------|
| **Speaker diarization with Sortformer** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| Speaker diarization with Diart | `diart` |
| Original Whisper backend | `whisper` |
| Improved timestamps backend | `whisper-timestamped` |
| Apple Silicon optimized backend | `mlx-whisper` |
| OpenAI API backend | `openai` |

See **Parameters & Configuration** below for how to use them.

### Usage Examples

**Command-line interface**: Start the transcription server with various options:

```bash
# Use a better model than the default (small)
whisperlivekit-server --model large-v3

# Advanced configuration with diarization and language selection
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
```

**Python API integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.

```python
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio

transcription_engine = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    global transcription_engine
    transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
    yield

app = FastAPI(lifespan=lifespan)

async def handle_websocket_results(websocket: WebSocket, results_generator):
    async for response in results_generator:
        await websocket.send_json(response)
    await websocket.send_json({"type": "ready_to_stop"})

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    global transcription_engine

    # Create a new AudioProcessor per connection and pass it the shared engine
    audio_processor = AudioProcessor(transcription_engine=transcription_engine)
    results_generator = await audio_processor.create_tasks()
    results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
    await websocket.accept()
    while True:
        message = await websocket.receive_bytes()
        await audio_processor.process_audio(message)
```

**Frontend implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it with `from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()`.

## Parameters & Configuration

You can change a list of important parameters. But what *should* you change?

- The `--model` size. List and recommendations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
- `--language`. List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). With `auto`, the model tries to detect the language automatically, but it tends to be biased towards English.
- `--backend`? If `simulstreaming` does not work correctly, or you want to avoid the dual-license requirements, you can switch to `--backend faster-whisper`.
- `--warmup-file`, if you have one
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`, if you set up a server
- `--diarization`, if you want to use it.

The rest is not recommended, but here are your options:

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--model` | Whisper model size. | `small` |
| `--language` | Source language code or `auto` | `auto` |
| `--task` | `transcribe` or `translate` | `transcribe` |
| `--backend` | Processing backend | `simulstreaming` |
| `--min-chunk-size` | Minimum audio chunk size (seconds) | `1.0` |
| `--no-vac` | Disable the Voice Activity Controller | `False` |
| `--no-vad` | Disable Voice Activity Detection | `False` |
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
| `--host` | Server host address | `localhost` |
| `--port` | Server port | `8000` |
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |

| WhisperStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |

| SimulStreaming backend options | Description | Default |
|-----------|-------------|---------|
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
| `--cif-ckpt-path` | Path to the CIF model for word boundary detection | `None` |
| `--never-fire` | Never truncate incomplete words | `False` |
| `--init-prompt` | Initial prompt for the model | `None` |
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
| `--max-context-tokens` | Maximum context tokens | `None` |
| `--model-path` | Direct path to a .pt model file; downloaded if not found | `./base.pt` |
| `--preloaded-model-count` | Optional. Number of models to preload in memory (set it up to the expected number of concurrent users) | `1` |

| Diarization options | Description | Default |
|-----------|-------------|---------|
| `--diarization` | Enable speaker identification | `False` |
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
| `--segmentation-model` | Hugging Face model ID for the Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
| `--embedding-model` | Hugging Face model ID for the Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |

> Diarization with Diart requires access to the pyannote.audio models:
> 1. [Accept the user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
> 2. [Accept the user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model
> 3. [Accept the user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
> 4. Log in to Hugging Face: `huggingface-cli login`
### 🚀 Deployment Guide

To deploy WhisperLiveKit in production:

1. **Server setup**: Install a production ASGI server and launch it with multiple workers

```bash
pip install uvicorn gunicorn
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
```

2. **Frontend**: Host your customized version of the `html` example and make sure the WebSocket connection points to the right place

3. **Nginx configuration** (recommended for production):

```nginx
server {
    listen 80;
    server_name your-domain.com;

    location / {
        proxy_pass http://localhost:8000;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
    }
}
```

4. **HTTPS support**: For secure deployments, use "wss://" instead of "ws://" in the WebSocket URL

## 🐋 Docker

Deploy the application easily using Docker, with GPU or CPU support.

### Prerequisites

- Docker installed on your system
- For GPU support: NVIDIA Docker runtime installed

### Quick Start

**With GPU acceleration (recommended):**

```bash
docker build -t wlk .
docker run --gpus all -p 8000:8000 --name wlk wlk
```

**CPU only:**

```bash
docker build -f Dockerfile.cpu -t wlk .
docker run -p 8000:8000 --name wlk wlk
```

### Advanced Usage

**Custom configuration:**

```bash
# Example with a custom model and language
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
```

### Memory Requirements

- **Large models**: Make sure your Docker runtime has enough memory allocated

#### Customization

- `--build-arg` options:
  - `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set the necessary container options!
  - `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
  - `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models

## 🔮 Use Cases

Capture discussions in real time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
||||||
BIN
architecture.png
Normal file
|
After Width: | Height: | Size: 406 KiB |
19
chrome-extension/README.md
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
## WhisperLiveKit Chrome Extension v0.1.1

Capture the audio of your current tab and transcribe, diarize, and translate it using WhisperLiveKit, in Chrome and other Chromium-based browsers.

> Currently, only the tab audio is captured; your microphone audio is not recorded.

<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">

## Running this extension

1. Run `python sync_extension.py` to copy the frontend files to the `chrome-extension` directory.
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.

## Devs:

- Unfortunately, it is impossible to capture audio from tabs if the extension is a side panel:
  - https://issues.chromium.org/issues/40926394
  - https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
  - https://issues.chromium.org/issues/40916430

- To capture the microphone in an extension, there are workarounds: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (comments)
||||||
9
chrome-extension/background.js
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
chrome.runtime.onInstalled.addListener((details) => {
  if (details.reason.search(/install/g) === -1) {
    return
  }
  chrome.tabs.create({
    url: chrome.runtime.getURL("welcome.html"),
    active: true
  })
})
||||||
BIN
chrome-extension/demo-extension.png
Normal file
|
After Width: | Height: | Size: 5.8 MiB |
BIN
chrome-extension/icons/icon128.png
Normal file
|
After Width: | Height: | Size: 5.8 KiB |
BIN
chrome-extension/icons/icon16.png
Normal file
|
After Width: | Height: | Size: 376 B |
BIN
chrome-extension/icons/icon32.png
Normal file
|
After Width: | Height: | Size: 823 B |
BIN
chrome-extension/icons/icon48.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
23
chrome-extension/manifest.json
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{
    "manifest_version": 3,
    "name": "WhisperLiveKit Tab Capture",
    "version": "1.0",
    "description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
    "icons": {
        "16": "icons/icon16.png",
        "32": "icons/icon32.png",
        "48": "icons/icon48.png",
        "128": "icons/icon128.png"
    },
    "action": {
        "default_title": "WhisperLiveKit Tab Capture",
        "default_popup": "live_transcription.html"
    },
    "permissions": [
        "scripting",
        "tabCapture",
        "offscreen",
        "activeTab",
        "storage"
    ]
}
||||||
12
chrome-extension/requestPermissions.html
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<!DOCTYPE html>
<html>
<head>
    <title>Request Permissions</title>
    <script src="requestPermissions.js"></script>
</head>
<body>
    This page exists to work around an issue with Chrome that blocks permission
    requests from Chrome extensions
    <button id="requestMicrophone">Request Microphone</button>
</body>
</html>
||||||
17
chrome-extension/requestPermissions.js
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
/**
 * Requests user permission for microphone access.
 * @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
 */
async function getUserPermission() {
  console.log("Getting user permission for microphone access...");
  await navigator.mediaDevices.getUserMedia({ audio: true });
  const micPermission = await navigator.permissions.query({
    name: "microphone",
  });
  if (micPermission.state == "granted") {
    window.close();
  }
}

// Call the function to request microphone permission
getUserPermission();
||||||
29
chrome-extension/sidepanel.js
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
console.log("sidepanel.js");
|
||||||
|
|
||||||
|
async function run() {
|
||||||
|
const micPermission = await navigator.permissions.query({
|
||||||
|
name: "microphone",
|
||||||
|
});
|
||||||
|
|
||||||
|
document.getElementById(
|
||||||
|
"audioPermission"
|
||||||
|
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||||
|
|
||||||
|
if (micPermission.state !== "granted") {
|
||||||
|
chrome.tabs.create({ url: "requestPermissions.html" });
|
||||||
|
}
|
||||||
|
|
||||||
|
const intervalId = setInterval(async () => {
|
||||||
|
const micPermission = await navigator.permissions.query({
|
||||||
|
name: "microphone",
|
||||||
|
});
|
||||||
|
if (micPermission.state === "granted") {
|
||||||
|
document.getElementById(
|
||||||
|
"audioPermission"
|
||||||
|
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||||
|
clearInterval(intervalId);
|
||||||
|
}
|
||||||
|
}, 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
void run();
|
||||||
BIN
demo.png
|
Before Width: | Height: | Size: 438 KiB After Width: | Height: | Size: 985 KiB |
264
docs/API.md
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
# WhisperLiveKit WebSocket API Documentation

> !! **Note**: The new API structure described in this document is currently being rolled out.

This documentation is intended for developers who want to build custom frontends.

WhisperLiveKit provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.

---

## Legacy API (Current)

### Message Structure

The current API sends a complete state snapshot on each update (several times per second):

```typescript
{
  "type": str,
  "status": str,
  "lines": [
    {
      "speaker": int,
      "text": str,
      "start": float,
      "end": float,
      "translation": str | null,
      "detected_language": str
    }
  ],
  "buffer_transcription": str,
  "buffer_diarization": str,
  "remaining_time_transcription": float,
  "remaining_time_diarization": float
}
```

---

## New API (Under Development)

### Philosophy

Principles:

- **Incremental Updates**: Only updates and new segments are sent
- **Ephemeral Buffers**: Temporary, unvalidated data is displayed in real time but overwritten on the next update, at speaker level

## Message Format
```typescript
{
  "type": "transcript_update",
  "status": "active_transcription" | "no_audio_detected",
  "segments": [
    {
      "id": number,
      "speaker": number,
      "text": string,
      "start_speaker": float,
      "start": float,
      "end": float,
      "language": string | null,
      "translation": string,
      "words": [
        {
          "text": string,
          "start": float,
          "end": float,
          "validated": {
            "text": boolean,
            "speaker": boolean,
          }
        }
      ],
      "buffer": {
        "transcription": string,
        "diarization": string,
        "translation": string
      }
    }
  ],
  "metadata": {
    "remaining_time_transcription": float,
    "remaining_time_diarization": float
  }
}
```

### Other Message Types

#### Config Message (sent on connection)

```json
{
  "type": "config",
  "useAudioWorklet": true / false
}
```

#### Ready to Stop Message (sent after processing complete)

```json
{
  "type": "ready_to_stop"
}
```
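
A client usually switches on the `type` field before doing anything else. The sketch below routes only the message types documented here; the callback names are illustrative and not part of the package:

```typescript
// Sketch: route incoming messages by their "type" field.
function attachMessageHandler(
  socket: WebSocket,
  onConfig: (useAudioWorklet: boolean) => void,   // hypothetical callbacks,
  onTranscript: (message: unknown) => void        // not part of the package
): void {
  socket.onmessage = (event: MessageEvent) => {
    const message = JSON.parse(event.data);
    switch (message.type) {
      case "config":
        onConfig(Boolean(message.useAudioWorklet)); // AudioWorklet (PCM) vs MediaRecorder input
        break;
      case "ready_to_stop":
        socket.close(); // processing is complete, safe to close
        break;
      default:
        onTranscript(message); // transcript updates
    }
  };
}
```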

---

## Field Descriptions

### Segment Fields
| Field | Type | Description |
|-------|------|-------------|
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
| `speaker` | `number` | Speaker ID (1, 2, 3...). The special value `-2` indicates silence. |
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until the language is detected. |
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
| `words` | `Array` | Array of word-level objects with timing and validation information. |
| `buffer` | `Object` | Per-segment temporary buffers; see below. |

### Word Object

| Field | Type | Description |
|-------|------|-------------|
| `text` | `string` | The word text. |
| `start` | `number` | Start timestamp (seconds) of this word. |
| `end` | `number` | End timestamp (seconds) of this word. |
| `validated.text` | `boolean` | Whether the transcription text has been validated. If false, the word is also in `buffer.transcription`. |
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. If false, the word is also in `buffer.diarization`. |
| `validated.language` | `boolean` | Whether the language detection has been validated. If false, the word is also in `buffer.translation`. |
### Buffer Object (Per-Segment)

Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and the previous buffer content is likely to appear in the next validated text.

| Field | Type | Description |
|-------|------|-------------|
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on the next update. |
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on the next update. |
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on the next update. |
### Metadata Fields

| Field | Type | Description |
|-------|------|-------------|
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |

### Status Values

| Status | Description |
|--------|-------------|
| `active_transcription` | Normal operation, transcription is active. |
| `no_audio_detected` | No audio has been detected yet. |

---

## Update Behavior

### Incremental Updates

The API sends **only changed or new segments**. Clients should:

1. Maintain a local map of segments by ID
2. When receiving an update, merge/update segments by ID
3. Render only the changed segments
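
A minimal sketch of that merge step, following the field semantics described above (`text` and `translation` are appended, other fields replaced); the interface and the map are illustrative, not part of the package. Note that the language-detection case described below replaces the whole segment instead of appending.

```typescript
interface Segment {
  id: number;
  speaker: number;
  text: string;
  translation: string;
  start: number;
  end: number;
  language: string | null;
}

// Local store of segments, keyed by id (illustrative).
const segments = new Map<number, Segment>();

function mergeSegment(update: Segment): void {
  const existing = segments.get(update.id);
  if (!existing) {
    segments.set(update.id, { ...update });
    return;
  }
  // text and translation carry only the newly validated part: append them.
  existing.text += update.text;
  existing.translation += update.translation;
  // the remaining fields describe the latest state: replace them.
  existing.speaker = update.speaker;
  existing.end = update.end;
  existing.language = update.language ?? existing.language;
}
```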

### Language Detection

When the language is detected for a segment:

```jsonc
// Update 1: No language yet
{
  "segments": [
    {"id": 1, "speaker": 1, "text": "May see", "language": null}
  ]
}

// Update 2: Same segment ID, language now detected
{
  "segments": [
    {"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
  ]
}
```

**Client behavior**: **Replace** the existing segment with the same ID.
### Buffer Behavior

Buffers are **per-segment** to handle multi-speaker scenarios correctly.

#### Example: Transcription with diarization and translation

```jsonc
// Update 1
{
  "segments": [
    {
      "id": 1,
      "speaker": 1,
      "text": "Hello world, how are",
      "translation": "",
      "buffer": {
        "transcription": "",
        "diarization": " you on",
        "translation": "Bonjour le monde"
      }
    }
  ]
}

// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>

// Update 2
{
  "segments": [
    {
      "id": 1,
      "speaker": 1,
      "text": " you on this",
      "translation": "Bonjour tout le monde",
      "buffer": {
        "transcription": "",
        "diarization": " beautiful day",
        "translation": ",comment"
      }
    }
  ]
}

// ==== Frontend ====
// <SPEAKER>1</SPEAKER>
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER></TRANSLATION>
```
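
A hedged sketch of the rendering shown in the comments above: validated text accumulates on the client, while the buffers are re-read from the latest update only. The relative order in which the two transcription-side buffers are concatenated is an assumption taken from the example.

```typescript
interface SegmentBuffers {
  transcription: string;
  diarization: string;
  translation: string;
}

// Compose one segment's display strings from validated text plus ephemeral buffers.
// Buffers come from the latest update only and must never be accumulated.
function renderSegment(
  validatedText: string,
  validatedTranslation: string,
  buffer: SegmentBuffers
): { transcription: string; translation: string } {
  return {
    transcription: validatedText + buffer.diarization + buffer.transcription,
    translation: validatedTranslation + buffer.translation,
  };
}
```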

### Silence Segments

Silence is represented with the speaker id `-2`:

```jsonc
{
  "id": 5,
  "speaker": -2,
  "text": "",
  "start": 10.5,
  "end": 12.3
}
```
109
docs/available_models.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Available Whisper model sizes:

- tiny.en (English only)
- tiny
- base.en (English only)
- base
- small.en (English only)
- small
- medium.en (English only)
- medium
- large-v1
- large-v2
- large-v3
- large-v3-turbo

## How to choose?

### Language Support
- **English only**: Use the `.en` models for better accuracy and faster processing when you only need English transcription
- **Multilingual**: Do not use the `.en` models.

### Resource Constraints
- **Limited GPU/CPU, or need for very low latency**: Choose `small` or smaller models
  - `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
  - `base`: Good balance of speed and accuracy for basic use cases
  - `small`: Better accuracy while still being resource-efficient
- **Good resources available**: Use `large` models for best accuracy
  - `large-v2`: Excellent accuracy, good multilingual support
  - `large-v3`: Best overall accuracy and language support

### Special Cases
- **No translation needed**: Use `large-v3-turbo`
  - Same transcription quality as `large-v2` but significantly faster
  - **Important**: Does not translate correctly, only transcribes

### Model Comparison Table

| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|-------|--------|----------|--------------|-------------|---------------|
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |

### Additional Considerations

**Model Performance**:
- Accuracy improves significantly from tiny to large models
- English-only models are ~10-15% more accurate for English audio
- Newer versions (v2, v3) have better punctuation and formatting

**Hardware Requirements**:
- `tiny`: ~1GB VRAM
- `base`: ~1GB VRAM
- `small`: ~2GB VRAM
- `medium`: ~5GB VRAM
- `large`: ~10GB VRAM
- `large-v3-turbo`: ~6GB VRAM

**Audio Quality Impact**:
- Clean, clear audio: smaller models may suffice
- Noisy, accented, or technical audio: larger models recommended
- Phone/low-quality audio: use at least the `small` model

### Quick Decision Tree
1. English only? → Add `.en` to your choice
2. Limited resources or need speed? → `small` or smaller
3. Good hardware and want the best quality? → `large-v3`
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)

_______________________

# Translation Models and Backend

**Language Support**: ~200 languages

## Distilled Model Sizes Available

| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|-------|------|------------|-------------|-------------|---------|
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy and context |

**Quality Impact**: 1.3B has ~15-25% better BLEU scores than 600M across language pairs.

## Backend Performance

| Backend | Speed vs Base | Memory Usage | Quality Loss |
|---------|---------------|--------------|--------------|
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
| Transformers | Baseline | High | None |
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |

**Metrics**:
- CTranslate2: 50-100+ tokens/sec
- Transformers: 10-30 tokens/sec
- Apple Silicon with MPS: up to 2x faster than CTranslate2

## Quick Decision Matrix

**Choose 600M**: Limited resources, close to zero lag
**Choose 1.3B**: Quality matters
**Choose Transformers**: On Apple Silicon
||||||
14
docs/models_compatible_formats.md
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# Model Path Formats

The `--model-path` parameter accepts:

## File Path
- **`.pt` format only** (required for the AlignAtt policy decoder)

## Directory Path (recommended)
Must contain:
- a **`.pt` file** (required for the decoder)

May optionally contain:
- a **`.bin` file** - faster-whisper model for the encoder (requires faster-whisper)
- **`weights.npz`** or **`weights.safetensors`** - for the encoder (requires whisper-mlx)
||||||
265
docs/supported_languages.md
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# Supported Languages

WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.

## How to Specify Languages

You can specify languages in **three different ways**:

1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`

## Usage Examples

### Command Line
```bash
# Using the language name
whisperlivekit-server --target-language "French"

# Using the ISO code
whisperlivekit-server --target-language fr

# Using the NLLB code
whisperlivekit-server --target-language fra_Latn
```

### Python API
```python
from nllw.translation import get_language_info

# Get language information by name
lang_info = get_language_info("French")
print(lang_info)
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}

# Get language information by ISO code
lang_info = get_language_info("fr")

# Get language information by NLLB code
lang_info = get_language_info("fra_Latn")

# All three return the same result
```
## Complete Language List
|
||||||
|
|
||||||
|
The following table lists all 201 supported languages with their corresponding codes:
|
||||||
|
|
||||||
|
| Language Name | ISO Code | NLLB Code |
|
||||||
|
|---------------|----------|-----------|
|
||||||
|
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||||
|
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||||
|
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||||
|
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||||
|
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||||
|
| Afrikaans | af | afr_Latn |
|
||||||
|
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||||
|
| Akan | ak | aka_Latn |
|
||||||
|
| Tosk Albanian | als | als_Latn |
|
||||||
|
| Amharic | am | amh_Ethi |
|
||||||
|
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||||
|
| Modern Standard Arabic | ar | arb_Arab |
|
||||||
|
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||||
|
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||||
|
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||||
|
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||||
|
| Assamese | as | asm_Beng |
|
||||||
|
| Asturian | ast | ast_Latn |
|
||||||
|
| Awadhi | awa | awa_Deva |
|
||||||
|
| Central Aymara | ay | ayr_Latn |
|
||||||
|
| South Azerbaijani | azb | azb_Arab |
|
||||||
|
| North Azerbaijani | az | azj_Latn |
|
||||||
|
| Bashkir | ba | bak_Cyrl |
|
||||||
|
| Bambara | bm | bam_Latn |
|
||||||
|
| Balinese | ban | ban_Latn |
|
||||||
|
| Belarusian | be | bel_Cyrl |
|
||||||
|
| Bemba | bem | bem_Latn |
|
||||||
|
| Bengali | bn | ben_Beng |
|
||||||
|
| Bhojpuri | bho | bho_Deva |
|
||||||
|
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||||
|
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||||
|
| Standard Tibetan | bo | bod_Tibt |
|
||||||
|
| Bosnian | bs | bos_Latn |
|
||||||
|
| Buginese | bug | bug_Latn |
|
||||||
|
| Bulgarian | bg | bul_Cyrl |
|
||||||
|
| Catalan | ca | cat_Latn |
|
||||||
|
| Cebuano | ceb | ceb_Latn |
|
||||||
|
| Czech | cs | ces_Latn |
|
||||||
|
| Chokwe | cjk | cjk_Latn |
|
||||||
|
| Central Kurdish | ckb | ckb_Arab |
|
||||||
|
| Crimean Tatar | crh | crh_Latn |
|
||||||
|
| Welsh | cy | cym_Latn |
|
||||||
|
| Danish | da | dan_Latn |
|
||||||
|
| German | de | deu_Latn |
|
||||||
|
| Southwestern Dinka | dik | dik_Latn |
|
||||||
|
| Dyula | dyu | dyu_Latn |
|
||||||
|
| Dzongkha | dz | dzo_Tibt |
|
||||||
|
| Greek | el | ell_Grek |
|
||||||
|
| English | en | eng_Latn |
|
||||||
|
| Esperanto | eo | epo_Latn |
|
||||||
|
| Estonian | et | est_Latn |
|
||||||
|
| Basque | eu | eus_Latn |
|
||||||
|
| Ewe | ee | ewe_Latn |
|
||||||
|
| Faroese | fo | fao_Latn |
|
||||||
|
| Fijian | fj | fij_Latn |
|
||||||
|
| Finnish | fi | fin_Latn |
|
||||||
|
| Fon | fon | fon_Latn |
|
||||||
|
| French | fr | fra_Latn |
|
||||||
|
| Friulian | fur-IT | fur_Latn |
|
||||||
|
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||||
|
| West Central Oromo | om | gaz_Latn |
|
||||||
|
| Scottish Gaelic | gd | gla_Latn |
|
||||||
|
| Irish | ga-IE | gle_Latn |
|
||||||
|
| Galician | gl | glg_Latn |
|
||||||
|
| Guarani | gn | grn_Latn |
|
||||||
|
| Gujarati | gu-IN | guj_Gujr |
|
||||||
|
| Haitian Creole | ht | hat_Latn |
|
||||||
|
| Hausa | ha | hau_Latn |
|
||||||
|
| Hebrew | he | heb_Hebr |
|
||||||
|
| Hindi | hi | hin_Deva |
|
||||||
|
| Chhattisgarhi | hne | hne_Deva |
|
||||||
|
| Croatian | hr | hrv_Latn |
|
||||||
|
| Hungarian | hu | hun_Latn |
|
||||||
|
| Armenian | hy-AM | hye_Armn |
|
||||||
|
| Igbo | ig | ibo_Latn |
|
||||||
|
| Ilocano | ilo | ilo_Latn |
|
||||||
|
| Indonesian | id | ind_Latn |
|
||||||
|
| Icelandic | is | isl_Latn |
|
||||||
|
| Italian | it | ita_Latn |
|
||||||
|
| Javanese | jv | jav_Latn |
|
||||||
|
| Japanese | ja | jpn_Jpan |
|
||||||
|
| Kabyle | kab | kab_Latn |
|
||||||
|
| Jingpho | kac | kac_Latn |
|
||||||
|
| Kamba | kam | kam_Latn |
|
||||||
|
| Kannada | kn | kan_Knda |
|
||||||
|
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||||
|
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||||
|
| Georgian | ka | kat_Geor |
|
||||||
|
| Kazakh | kk | kaz_Cyrl |
|
||||||
|
| Kabiyè | kbp | kbp_Latn |
|
||||||
|
| Kabuverdianu | kea | kea_Latn |
|
||||||
|
| Halh Mongolian | mn | khk_Cyrl |
|
||||||
|
| Khmer | km | khm_Khmr |
|
||||||
|
| Kikuyu | ki | kik_Latn |
|
||||||
|
| Kinyarwanda | rw | kin_Latn |
|
||||||
|
| Kyrgyz | ky | kir_Cyrl |
|
||||||
|
| Kimbundu | kmb | kmb_Latn |
|
||||||
|
| Northern Kurdish | kmr | kmr_Latn |
|
||||||
|
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||||
|
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||||
|
| Kikongo | kg | kon_Latn |
|
||||||
|
| Korean | ko | kor_Hang |
|
||||||
|
| Lao | lo | lao_Laoo |
|
||||||
|
| Ligurian | lij | lij_Latn |
|
||||||
|
| Limburgish | li | lim_Latn |
|
||||||
|
| Lingala | ln | lin_Latn |
|
||||||
|
| Lithuanian | lt | lit_Latn |
|
||||||
|
| Lombard | lmo | lmo_Latn |
|
||||||
|
| Latgalian | ltg | ltg_Latn |
|
||||||
|
| Luxembourgish | lb | ltz_Latn |
|
||||||
|
| Luba-Kasai | lua | lua_Latn |
|
||||||
|
| Ganda | lg | lug_Latn |
|
||||||
|
| Luo | luo | luo_Latn |
|
||||||
|
| Mizo | lus | lus_Latn |
|
||||||
|
| Standard Latvian | lv | lvs_Latn |
|
||||||
|
| Magahi | mag | mag_Deva |
|
||||||
|
| Maithili | mai | mai_Deva |
|
||||||
|
| Malayalam | ml-IN | mal_Mlym |
|
||||||
|
| Marathi | mr | mar_Deva |
|
||||||
|
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||||
|
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||||
|
| Macedonian | mk | mkd_Cyrl |
|
||||||
|
| Maltese | mt | mlt_Latn |
|
||||||
|
| Meitei (Bengali script) | mni | mni_Beng |
|
||||||
|
| Mossi | mos | mos_Latn |
|
||||||
|
| Maori | mi | mri_Latn |
|
||||||
|
| Burmese | my | mya_Mymr |
|
||||||
|
| Dutch | nl | nld_Latn |
|
||||||
|
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||||
|
| Norwegian Bokmål | nb | nob_Latn |
|
||||||
|
| Nepali | ne-NP | npi_Deva |
|
||||||
|
| Northern Sotho | nso | nso_Latn |
|
||||||
|
| Nuer | nus | nus_Latn |
|
||||||
|
| Nyanja | ny | nya_Latn |
|
||||||
|
| Occitan | oc | oci_Latn |
|
||||||
|
| Odia | or | ory_Orya |
|
||||||
|
| Pangasinan | pag | pag_Latn |
|
||||||
|
| Eastern Panjabi | pa | pan_Guru |
|
||||||
|
| Papiamento | pap | pap_Latn |
|
||||||
|
| Southern Pashto | pbt | pbt_Arab |
|
||||||
|
| Western Persian | fa | pes_Arab |
|
||||||
|
| Plateau Malagasy | mg | plt_Latn |
|
||||||
|
| Polish | pl | pol_Latn |
|
||||||
|
| Portuguese | pt-PT | por_Latn |
|
||||||
|
| Dari | fa-AF | prs_Arab |
|
||||||
|
| Ayacucho Quechua | qu | quy_Latn |
|
||||||
|
| Romanian | ro | ron_Latn |
|
||||||
|
| Rundi | rn | run_Latn |
|
||||||
|
| Russian | ru | rus_Cyrl |
|
||||||
|
| Sango | sg | sag_Latn |
|
||||||
|
| Sanskrit | sa | san_Deva |
|
||||||
|
| Santali | sat | sat_Olck |
|
||||||
|
| Sicilian | scn | scn_Latn |
|
||||||
|
| Shan | shn | shn_Mymr |
|
||||||
|
| Sinhala | si-LK | sin_Sinh |
|
||||||
|
| Slovak | sk | slk_Latn |
|
||||||
|
| Slovenian | sl | slv_Latn |
|
||||||
|
| Samoan | sm | smo_Latn |
|
||||||
|
| Shona | sn | sna_Latn |
|
||||||
|
| Sindhi | sd | snd_Arab |
|
||||||
|
| Somali | so | som_Latn |
|
||||||
|
| Southern Sotho | st | sot_Latn |
|
||||||
|
| Spanish | es-ES | spa_Latn |
|
||||||
|
| Sardinian | sc | srd_Latn |
|
||||||
|
| Serbian | sr | srp_Cyrl |
|
||||||
|
| Swati | ss | ssw_Latn |
|
||||||
|
| Sundanese | su | sun_Latn |
|
||||||
|
| Swedish | sv-SE | swe_Latn |
|
||||||
|
| Swahili | sw | swh_Latn |
|
||||||
|
| Silesian | szl | szl_Latn |
|
||||||
|
| Tamil | ta | tam_Taml |
|
||||||
|
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||||
|
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||||
|
| Tatar | tt-RU | tat_Cyrl |
|
||||||
|
| Telugu | te | tel_Telu |
|
||||||
|
| Tajik | tg | tgk_Cyrl |
|
||||||
|
| Tagalog | tl | tgl_Latn |
|
||||||
|
| Thai | th | tha_Thai |
|
||||||
|
| Tigrinya | ti | tir_Ethi |
|
||||||
|
| Tok Pisin | tpi | tpi_Latn |
|
||||||
|
| Tswana | tn | tsn_Latn |
|
||||||
|
| Tsonga | ts | tso_Latn |
|
||||||
|
| Turkmen | tk | tuk_Latn |
|
||||||
|
| Tumbuka | tum | tum_Latn |
|
||||||
|
| Turkish | tr | tur_Latn |
|
||||||
|
| Twi | tw | twi_Latn |
|
||||||
|
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||||
|
| Uyghur | ug | uig_Arab |
|
||||||
|
| Ukrainian | uk | ukr_Cyrl |
|
||||||
|
| Umbundu | umb | umb_Latn |
|
||||||
|
| Urdu | ur | urd_Arab |
|
||||||
|
| Northern Uzbek | uz | uzn_Latn |
|
||||||
|
| Venetian | vec | vec_Latn |
|
||||||
|
| Vietnamese | vi | vie_Latn |
|
||||||
|
| Waray | war | war_Latn |
|
||||||
|
| Wolof | wo | wol_Latn |
|
||||||
|
| Xhosa | xh | xho_Latn |
|
||||||
|
| Eastern Yiddish | yi | ydd_Hebr |
|
||||||
|
| Yoruba | yo | yor_Latn |
|
||||||
|
| Yue Chinese | yue | yue_Hant |
|
||||||
|
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||||
|
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||||
|
| Standard Malay | ms | zsm_Latn |
|
||||||
|
| Zulu | zu | zul_Latn |
|
||||||
|
|
||||||
|
## Special Features

### Multiple Script Support
Several languages are available in multiple scripts (e.g., Arabic and Latin):
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
69
pyproject.toml
Normal file
@@ -0,0 +1,69 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "whisperlivekit"
version = "0.2.13"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
    { name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Programming Language :: Python :: 3.14",
    "Programming Language :: Python :: 3.15",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
    "fastapi",
    "librosa",
    "soundfile",
    "faster-whisper",
    "uvicorn",
    "websockets",
    "torchaudio>=2.0.0",
    "torch>=2.0.0",
    "tqdm",
    "tiktoken",
    'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]

[project.optional-dependencies]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]

[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"

[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"

[tool.setuptools]
packages = [
    "whisperlivekit",
    "whisperlivekit.diarization",
    "whisperlivekit.simul_whisper",
    "whisperlivekit.simul_whisper.whisper",
    "whisperlivekit.simul_whisper.whisper.assets",
    "whisperlivekit.simul_whisper.whisper.normalizers",
    "whisperlivekit.web",
    "whisperlivekit.whisper_streaming_custom",
    "whisperlivekit.vad_models"
]

[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.simul_whisper.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
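The new `translation` extra maps to the `nllw` package that the engine imports lazily. A small sketch of the same guard, handy when embedding the library (the pip spelling follows the extra name defined above):

```python
# Check that the optional translation dependency is available before enabling it.
try:
    from nllw import load_model  # installed via: pip install "whisperlivekit[translation]"
except ImportError:
    load_model = None
    print('Translation disabled: install it with pip install "whisperlivekit[translation]"')
```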
setup.py (deleted, 54 lines)
@@ -1,54 +0,0 @@
from setuptools import setup, find_packages

setup(
    name="whisperlivekit",
    version="0.2.1",
    description="Real-time, Fully Local Whisper's Speech-to-Text and Speaker Diarization",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    author="Quentin Fuxa",
    url="https://github.com/QuentinFuxa/WhisperLiveKit",
    packages=find_packages(),
    install_requires=[
        "fastapi",
        "ffmpeg-python",
        "librosa",
        "soundfile",
        "faster-whisper",
        "uvicorn",
        "websockets",
    ],
    extras_require={
        "diarization": ["diart"],
        "vac": ["torch"],
        "sentence": ["mosestokenizer", "wtpsplit"],
        "whisper": ["whisper"],
        "whisper-timestamped": ["whisper-timestamped"],
        "mlx-whisper": ["mlx-whisper"],
        "openai": ["openai"],
        "simulstreaming": [
            "torch",
            "tqdm",
            "tiktoken",
            "triton>=2.0.0,<3;platform_machine==\"x86_64\" and sys_platform==\"linux\" or sys_platform==\"linux2\"",
        ],
    },
    package_data={
        'whisperlivekit': ['web/*.html'],
        'whisperlivekit.simul_whisper': ['dual_license_simulstreaming.md'],
    },
    entry_points={
        'console_scripts': [
            'whisperlivekit-server=whisperlivekit.basic_server:main',
        ],
    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Multimedia :: Sound/Audio :: Speech",
    ],
    python_requires=">=3.9",
)
sync_extension.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import shutil
import os
from pathlib import Path


def sync_extension_files():
    """Copy core files from web directory to Chrome extension directory."""

    web_dir = Path("whisperlivekit/web")
    extension_dir = Path("chrome-extension")

    files_to_sync = [
        "live_transcription.html", "live_transcription.js", "live_transcription.css"
    ]

    svg_files = [
        "system_mode.svg",
        "light_mode.svg",
        "dark_mode.svg",
        "settings.svg"
    ]

    for file in files_to_sync:
        src_path = web_dir / file
        dest_path = extension_dir / file

        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_path, dest_path)

    for svg_file in svg_files:
        src_path = web_dir / "src" / svg_file
        dest_path = extension_dir / "web" / "src" / svg_file
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_path, dest_path)


if __name__ == "__main__":
    sync_extension_files()
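The helper uses paths relative to the repository root (`whisperlivekit/web` and `chrome-extension`), so it should be run from there; a programmatic call is equivalent to the `__main__` block:

```python
# Run from the repository root so the relative paths in the helper resolve.
from sync_extension import sync_extension_files

sync_extension_files()
```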
@@ -1,5 +1,13 @@
-from .core import TranscriptionEngine
 from .audio_processor import AudioProcessor
-from .web.web_interface import get_web_interface_html
+from .core import TranscriptionEngine
 from .parse_args import parse_args
-__all__ = ['TranscriptionEngine', 'AudioProcessor', 'get_web_interface_html', 'parse_args']
+from .web.web_interface import get_web_interface_html, get_inline_ui_html
+
+__all__ = [
+    "TranscriptionEngine",
+    "AudioProcessor",
+    "parse_args",
+    "get_web_interface_html",
+    "get_inline_ui_html",
+    "download_simulstreaming_backend",
+]
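A quick sanity check of the reshaped public surface is sketched below, using only the names whose imports are visible in this hunk; `download_simulstreaming_backend` is listed in `__all__` but its import is not shown here, so it is left out:

```python
# Minimal import smoke test for the updated package API.
from whisperlivekit import (
    TranscriptionEngine,
    AudioProcessor,
    parse_args,
    get_web_interface_html,
    get_inline_ui_html,
)
```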
@@ -2,7 +2,7 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
-from whisperlivekit import TranscriptionEngine, AudioProcessor, get_web_interface_html, parse_args
+from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
 import asyncio
 import logging
 
@@ -15,7 +15,7 @@ args = parse_args()
 transcription_engine = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     global transcription_engine
     transcription_engine = TranscriptionEngine(
         **vars(args),
 
@@ -33,21 +33,21 @@ app.add_middleware(
 
 @app.get("/")
 async def get():
-    return HTMLResponse(get_web_interface_html())
+    return HTMLResponse(get_inline_ui_html())
 
 
 async def handle_websocket_results(websocket, results_generator):
     """Consumes results from the audio processor and sends them via WebSocket."""
     try:
         async for response in results_generator:
-            await websocket.send_json(response)
+            await websocket.send_json(response.to_dict())
         # when the results_generator finishes it means all audio has been processed
         logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
         await websocket.send_json({"type": "ready_to_stop"})
     except WebSocketDisconnect:
         logger.info("WebSocket disconnected while handling results (client likely closed connection).")
     except Exception as e:
-        logger.warning(f"Error in WebSocket results handler: {e}")
+        logger.exception(f"Error in WebSocket results handler: {e}")
 
 
 @app.websocket("/asr")
@@ -58,6 +58,11 @@ async def websocket_endpoint(websocket: WebSocket):
     )
     await websocket.accept()
     logger.info("WebSocket connection opened.")
 
+    try:
+        await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
+    except Exception as e:
+        logger.warning(f"Failed to send config to client: {e}")
+
     results_generator = await audio_processor.create_tasks()
     websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
@@ -113,6 +118,8 @@ def main():
 
     if ssl_kwargs:
         uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
+    if args.forwarded_allow_ips:
+        uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
 
     uvicorn.run(**uvicorn_kwargs)
 
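The server hunks above (presumably `whisperlivekit/basic_server.py`; the extraction dropped the filename header) now push a one-off `config` message right after accepting the socket and end the stream with `ready_to_stop`. A minimal client sketch, assuming the default host and port shown in the engine defaults and the `websockets` package already listed as a dependency; a real client would also stream audio bytes over the same socket, this only shows the control messages added here:

```python
import asyncio
import json

import websockets  # already a project dependency


async def listen():
    # Hypothetical endpoint: ws://localhost:8000/asr matches the defaults
    # ("host": "localhost", "port": 8000) and the @app.websocket("/asr") route.
    async with websockets.connect("ws://localhost:8000/asr") as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg.get("type") == "config":
                print("server expects AudioWorklet PCM input:", msg.get("useAudioWorklet"))
            elif msg.get("type") == "ready_to_stop":
                break
            else:
                print(msg)  # transcription / diarization payloads


if __name__ == "__main__":
    asyncio.run(listen())
```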
@@ -1,9 +1,17 @@
 try:
-    from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+    from whisperlivekit.whisper_streaming_custom.whisper_online import backend_factory
+    from whisperlivekit.whisper_streaming_custom.online_asr import OnlineASRProcessor
 except ImportError:
-    from .whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+    from .whisper_streaming_custom.whisper_online import backend_factory
+    from .whisper_streaming_custom.online_asr import OnlineASRProcessor
 from argparse import Namespace
+import sys
+
+def update_with_kwargs(_dict, kwargs):
+    _dict.update({
+        k: v for k, v in kwargs.items() if k in _dict
+    })
+    return _dict
 
 class TranscriptionEngine:
     _instance = None
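`update_with_kwargs` only overrides keys that already exist in the target dict, so unrelated kwargs fall through silently while known defaults are replaced. A toy check (the import path assumes this hunk lives in `whisperlivekit/core.py`, which the extraction does not name explicitly):

```python
from whisperlivekit.core import update_with_kwargs  # assumed module path

defaults = {"model_size": "tiny", "lan": "auto"}
merged = update_with_kwargs(defaults, {"model_size": "large-v3", "unknown_flag": True})
assert merged == {"model_size": "large-v3", "lan": "auto"}  # unknown_flag is ignored
```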
@@ -18,75 +26,157 @@ class TranscriptionEngine:
|
|||||||
if TranscriptionEngine._initialized:
|
if TranscriptionEngine._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
defaults = {
|
global_params = {
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
"warmup_file": None,
|
|
||||||
"confidence_validation": False,
|
|
||||||
"diarization": False,
|
"diarization": False,
|
||||||
"punctuation_split": False,
|
"punctuation_split": False,
|
||||||
|
"target_language": "",
|
||||||
|
"vac": True,
|
||||||
|
"vac_onnx": False,
|
||||||
|
"vac_chunk_size": 0.04,
|
||||||
|
"log_level": "DEBUG",
|
||||||
|
"ssl_certfile": None,
|
||||||
|
"ssl_keyfile": None,
|
||||||
|
"forwarded_allow_ips": None,
|
||||||
|
"transcription": True,
|
||||||
|
"vad": True,
|
||||||
|
"pcm_input": False,
|
||||||
|
"disable_punctuation_split" : False,
|
||||||
|
"diarization_backend": "sortformer",
|
||||||
|
}
|
||||||
|
global_params = update_with_kwargs(global_params, kwargs)
|
||||||
|
|
||||||
|
transcription_common_params = {
|
||||||
|
"backend": "simulstreaming",
|
||||||
|
"warmup_file": None,
|
||||||
"min_chunk_size": 0.5,
|
"min_chunk_size": 0.5,
|
||||||
"model": "tiny",
|
"model_size": "tiny",
|
||||||
"model_cache_dir": None,
|
"model_cache_dir": None,
|
||||||
"model_dir": None,
|
"model_dir": None,
|
||||||
"lan": "auto",
|
"lan": "auto",
|
||||||
"task": "transcribe",
|
"task": "transcribe",
|
||||||
"backend": "faster-whisper",
|
|
||||||
"vac": False,
|
|
||||||
"vac_chunk_size": 0.04,
|
|
||||||
"buffer_trimming": "segment",
|
|
||||||
"buffer_trimming_sec": 15,
|
|
||||||
"log_level": "DEBUG",
|
|
||||||
"ssl_certfile": None,
|
|
||||||
"ssl_keyfile": None,
|
|
||||||
"transcription": True,
|
|
||||||
"vad": True,
|
|
||||||
"segmentation_model": "pyannote/segmentation-3.0",
|
|
||||||
"embedding_model": "pyannote/embedding",
|
|
||||||
# simulstreaming params:
|
|
||||||
"frame_threshold": 25,
|
|
||||||
"beams": 1,
|
|
||||||
"decoder_type": None,
|
|
||||||
"audio_max_len": 30.0,
|
|
||||||
"audio_min_len": 0.0,
|
|
||||||
"cif_ckpt_path": None,
|
|
||||||
"never_fire": False,
|
|
||||||
"init_prompt": None,
|
|
||||||
"static_init_prompt": None,
|
|
||||||
"max_context_tokens": None,
|
|
||||||
"model_path": './base.pt',
|
|
||||||
}
|
}
|
||||||
|
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||||
|
|
||||||
config_dict = {**defaults, **kwargs}
|
if transcription_common_params['model_size'].endswith(".en"):
|
||||||
|
transcription_common_params["lan"] = "en"
|
||||||
if 'no_transcription' in kwargs:
|
if 'no_transcription' in kwargs:
|
||||||
config_dict['transcription'] = not kwargs['no_transcription']
|
global_params['transcription'] = not global_params['no_transcription']
|
||||||
if 'no_vad' in kwargs:
|
if 'no_vad' in kwargs:
|
||||||
config_dict['vad'] = not kwargs['no_vad']
|
global_params['vad'] = not kwargs['no_vad']
|
||||||
|
if 'no_vac' in kwargs:
|
||||||
config_dict.pop('no_transcription', None)
|
global_params['vac'] = not kwargs['no_vac']
|
||||||
config_dict.pop('no_vad', None)
|
|
||||||
|
|
||||||
if 'language' in kwargs:
|
self.args = Namespace(**{**global_params, **transcription_common_params})
|
||||||
config_dict['lan'] = kwargs['language']
|
|
||||||
config_dict.pop('language', None)
|
|
||||||
|
|
||||||
self.args = Namespace(**config_dict)
|
|
||||||
|
|
||||||
self.asr = None
|
self.asr = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.diarization = None
|
self.diarization = None
|
||||||
|
self.vac_model = None
|
||||||
|
|
||||||
|
if self.args.vac:
|
||||||
|
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||||
|
# Use ONNX if specified, otherwise use JIT (default)
|
||||||
|
use_onnx = kwargs.get('vac_onnx', False)
|
||||||
|
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||||
|
|
||||||
if self.args.transcription:
|
if self.args.transcription:
|
||||||
self.asr, self.tokenizer = backend_factory(self.args)
|
if self.args.backend == "simulstreaming":
|
||||||
warmup_asr(self.asr, self.args.warmup_file)
|
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||||
|
|
||||||
|
simulstreaming_params = {
|
||||||
|
"disable_fast_encoder": False,
|
||||||
|
"custom_alignment_heads": None,
|
||||||
|
"frame_threshold": 25,
|
||||||
|
"beams": 1,
|
||||||
|
"decoder_type": None,
|
||||||
|
"audio_max_len": 20.0,
|
||||||
|
"audio_min_len": 0.0,
|
||||||
|
"cif_ckpt_path": None,
|
||||||
|
"never_fire": False,
|
||||||
|
"init_prompt": None,
|
||||||
|
"static_init_prompt": None,
|
||||||
|
"max_context_tokens": None,
|
||||||
|
"model_path": './base.pt',
|
||||||
|
"preload_model_count": 1,
|
||||||
|
}
|
||||||
|
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||||
|
|
||||||
|
self.tokenizer = None
|
||||||
|
self.asr = SimulStreamingASR(
|
||||||
|
**transcription_common_params, **simulstreaming_params
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
whisperstreaming_params = {
|
||||||
|
"buffer_trimming": "segment",
|
||||||
|
"confidence_validation": False,
|
||||||
|
"buffer_trimming_sec": 15,
|
||||||
|
}
|
||||||
|
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||||
|
|
||||||
|
self.asr = backend_factory(
|
||||||
|
**transcription_common_params, **whisperstreaming_params
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.diarization:
|
if self.args.diarization:
|
||||||
from whisperlivekit.diarization.diarization_online import DiartDiarization
|
if self.args.diarization_backend == "diart":
|
||||||
self.diarization = DiartDiarization(
|
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||||
block_duration=self.args.min_chunk_size,
|
diart_params = {
|
||||||
segmentation_model_name=self.args.segmentation_model,
|
"segmentation_model": "pyannote/segmentation-3.0",
|
||||||
embedding_model_name=self.args.embedding_model
|
"embedding_model": "pyannote/embedding",
|
||||||
)
|
}
|
||||||
|
diart_params = update_with_kwargs(diart_params, kwargs)
|
||||||
|
self.diarization_model = DiartDiarization(
|
||||||
|
block_duration=self.args.min_chunk_size,
|
||||||
|
**diart_params
|
||||||
|
)
|
||||||
|
elif self.args.diarization_backend == "sortformer":
|
||||||
|
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||||
|
self.diarization_model = SortformerDiarization()
|
||||||
|
|
||||||
|
self.translation_model = None
|
||||||
|
if self.args.target_language:
|
||||||
|
if self.args.lan == 'auto' and self.args.backend != "simulstreaming":
|
||||||
|
raise Exception('Translation cannot be set with language auto when transcription backend is not simulstreaming')
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from nllw import load_model
|
||||||
|
except:
|
||||||
|
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
||||||
|
translation_params = {
|
||||||
|
"nllb_backend": "transformers",
|
||||||
|
"nllb_size": "600M"
|
||||||
|
}
|
||||||
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||||
|
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||||
TranscriptionEngine._initialized = True
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
def online_factory(args, asr):
|
||||||
|
if args.backend == "simulstreaming":
|
||||||
|
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||||
|
online = SimulStreamingOnlineProcessor(asr)
|
||||||
|
else:
|
||||||
|
online = OnlineASRProcessor(asr)
|
||||||
|
return online
|
||||||
|
|
||||||
|
|
||||||
|
def online_diarization_factory(args, diarization_backend):
|
||||||
|
if args.diarization_backend == "diart":
|
||||||
|
online = diarization_backend
|
||||||
|
# Not the best here, since several user/instances will share the same backend, but diart is not SOTA anymore and sortformer is recommended
|
||||||
|
|
||||||
|
if args.diarization_backend == "sortformer":
|
||||||
|
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||||
|
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||||
|
return online
|
||||||
|
|
||||||
|
|
||||||
|
def online_translation_factory(args, translation_model):
|
||||||
|
#should be at speaker level in the future:
|
||||||
|
#one shared nllb model for all speaker
|
||||||
|
#one tokenizer per speaker/language
|
||||||
|
from nllw import OnlineTranslation
|
||||||
|
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||||
|
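Taken together, the rewritten constructor above routes every kwarg through `update_with_kwargs` into separate global, transcription-common, backend-specific, diarization and translation parameter groups. A minimal construction sketch under the defaults visible in the diff (simulstreaming backend, Sortformer diarization), assuming the relevant optional dependencies are installed; kwargs not listed in those dicts are simply ignored:

```python
from whisperlivekit import TranscriptionEngine

engine = TranscriptionEngine(
    model_size="tiny",                 # renamed from the old "model" key
    lan="en",
    backend="simulstreaming",          # new default backend
    diarization=True,
    diarization_backend="sortformer",  # requires the NeMo install used by the Sortformer backend below
)
```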
|||||||
@@ -29,6 +29,7 @@ class DiarizationObserver(Observer):
|
|||||||
self.speaker_segments = []
|
self.speaker_segments = []
|
||||||
self.processed_time = 0
|
self.processed_time = 0
|
||||||
self.segment_lock = threading.Lock()
|
self.segment_lock = threading.Lock()
|
||||||
|
self.global_time_offset = 0.0
|
||||||
|
|
||||||
def on_next(self, value: Tuple[Annotation, Any]):
|
def on_next(self, value: Tuple[Annotation, Any]):
|
||||||
annotation, audio = value
|
annotation, audio = value
|
||||||
@@ -49,8 +50,8 @@ class DiarizationObserver(Observer):
|
|||||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||||
self.speaker_segments.append(SpeakerSegment(
|
self.speaker_segments.append(SpeakerSegment(
|
||||||
speaker=speaker,
|
speaker=speaker,
|
||||||
start=start,
|
start=start + self.global_time_offset,
|
||||||
end=end
|
end=end + self.global_time_offset
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
logger.debug("\nNo speakers detected in this segment")
|
logger.debug("\nNo speakers detected in this segment")
|
||||||
@@ -165,7 +166,7 @@ class WebSocketAudioSource(AudioSource):
|
|||||||
|
|
||||||
|
|
||||||
class DiartDiarization:
|
class DiartDiarization:
|
||||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 0.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "speechbrain/spkrec-ecapa-voxceleb"):
|
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||||
|
|
||||||
@@ -199,6 +200,9 @@ class DiartDiarization:
|
|||||||
self.inference.attach_observers(self.observer)
|
self.inference.attach_observers(self.observer)
|
||||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration):
|
||||||
|
self.observer.global_time_offset += silence_duration
|
||||||
|
|
||||||
async def diarize(self, pcm_array: np.ndarray):
|
async def diarize(self, pcm_array: np.ndarray):
|
||||||
"""
|
"""
|
||||||
Process audio data for diarization.
|
Process audio data for diarization.
|
||||||
@@ -206,15 +210,14 @@ class DiartDiarization:
|
|||||||
"""
|
"""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.push_audio(pcm_array)
|
self.custom_source.push_audio(pcm_array)
|
||||||
self.observer.clear_old_segments()
|
# self.observer.clear_old_segments()
|
||||||
return self.observer.get_segments()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the audio source."""
|
"""Close the audio source."""
|
||||||
if self.custom_source:
|
if self.custom_source:
|
||||||
self.custom_source.close()
|
self.custom_source.close()
|
||||||
|
|
||||||
def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list, use_punctuation_split: bool = False) -> float:
|
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> float:
|
||||||
"""
|
"""
|
||||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
Uses the segments collected by the observer.
|
Uses the segments collected by the observer.
|
||||||
@@ -231,85 +234,82 @@ class DiartDiarization:
|
|||||||
|
|
||||||
if not self.lag_diart and segments and tokens:
|
if not self.lag_diart and segments and tokens:
|
||||||
self.lag_diart = segments[0].start - tokens[0].start
|
self.lag_diart = segments[0].start - tokens[0].start
|
||||||
for token in tokens:
|
|
||||||
for segment in segments:
|
|
||||||
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
|
||||||
token.speaker = extract_number(segment.speaker) + 1
|
|
||||||
end_attributed_speaker = max(token.end, end_attributed_speaker)
|
|
||||||
|
|
||||||
if use_punctuation_split and len(tokens) > 1:
|
if not use_punctuation_split:
|
||||||
punctuation_marks = {'.', '!', '?'}
|
for token in tokens:
|
||||||
|
for segment in segments:
|
||||||
print("Here are the tokens:",
|
if not (segment.end <= token.start + self.lag_diart or segment.start >= token.end + self.lag_diart):
|
||||||
[(t.text, t.start, t.end, t.speaker) for t in tokens[:10]])
|
token.speaker = extract_number(segment.speaker) + 1
|
||||||
|
else:
|
||||||
segment_map = []
|
tokens = add_speaker_to_tokens(segments, tokens)
|
||||||
for segment in segments:
|
return tokens
|
||||||
speaker_num = extract_number(segment.speaker) + 1
|
|
||||||
segment_map.append((segment.start, segment.end, speaker_num))
|
|
||||||
segment_map.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
current_token = tokens[i]
|
|
||||||
|
|
||||||
is_sentence_end = False
|
|
||||||
if current_token.text and current_token.text.strip():
|
|
||||||
text = current_token.text.strip()
|
|
||||||
if text[-1] in punctuation_marks:
|
|
||||||
is_sentence_end = True
|
|
||||||
logger.debug(f"Token {i} ends sentence: '{current_token.text}' at {current_token.end:.2f}s")
|
|
||||||
|
|
||||||
if is_sentence_end and current_token.speaker != -1:
|
|
||||||
punctuation_time = current_token.end
|
|
||||||
current_speaker = current_token.speaker
|
|
||||||
|
|
||||||
j = i + 1
|
|
||||||
next_sentence_tokens = []
|
|
||||||
while j < len(tokens):
|
|
||||||
next_token = tokens[j]
|
|
||||||
next_sentence_tokens.append(j)
|
|
||||||
|
|
||||||
# Check if this token ends the next sentence
|
|
||||||
if next_token.text and next_token.text.strip():
|
|
||||||
if next_token.text.strip()[-1] in punctuation_marks:
|
|
||||||
break
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
if next_sentence_tokens:
|
|
||||||
speaker_times = {}
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
token = tokens[idx]
|
|
||||||
# Find which segments overlap with this token
|
|
||||||
for seg_start, seg_end, seg_speaker in segment_map:
|
|
||||||
if not (seg_end <= token.start or seg_start >= token.end):
|
|
||||||
# Calculate overlap duration
|
|
||||||
overlap_start = max(seg_start, token.start)
|
|
||||||
overlap_end = min(seg_end, token.end)
|
|
||||||
overlap_duration = overlap_end - overlap_start
|
|
||||||
|
|
||||||
if seg_speaker not in speaker_times:
|
|
||||||
speaker_times[seg_speaker] = 0
|
|
||||||
speaker_times[seg_speaker] += overlap_duration
|
|
||||||
|
|
||||||
if speaker_times:
|
|
||||||
dominant_speaker = max(speaker_times.items(), key=lambda x: x[1])[0]
|
|
||||||
|
|
||||||
if dominant_speaker != current_speaker:
|
|
||||||
logger.debug(f" Speaker change after punctuation: {current_speaker} → {dominant_speaker}")
|
|
||||||
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker != dominant_speaker:
|
|
||||||
logger.debug(f" Reassigning token {idx} ('{tokens[idx].text}') to Speaker {dominant_speaker}")
|
|
||||||
tokens[idx].speaker = dominant_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
else:
|
|
||||||
for idx in next_sentence_tokens:
|
|
||||||
if tokens[idx].speaker == -1:
|
|
||||||
tokens[idx].speaker = current_speaker
|
|
||||||
end_attributed_speaker = max(tokens[idx].end, end_attributed_speaker)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return end_attributed_speaker
|
def concatenate_speakers(segments):
|
||||||
|
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||||
|
for segment in segments:
|
||||||
|
speaker = extract_number(segment.speaker) + 1
|
||||||
|
if segments_concatenated[-1]['speaker'] != speaker:
|
||||||
|
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||||
|
else:
|
||||||
|
segments_concatenated[-1]['end'] = segment.end
|
||||||
|
# print("Segments concatenated:")
|
||||||
|
# for entry in segments_concatenated:
|
||||||
|
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||||
|
return segments_concatenated
|
||||||
|
|
||||||
|
|
||||||
|
def add_speaker_to_tokens(segments, tokens):
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||||
|
"""
|
||||||
|
punctuation_marks = {'.', '!', '?'}
|
||||||
|
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||||
|
segments_concatenated = concatenate_speakers(segments)
|
||||||
|
for ind, segment in enumerate(segments_concatenated):
|
||||||
|
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||||
|
if punctuation_token.start > segment['end']:
|
||||||
|
after_length = punctuation_token.start - segment['end']
|
||||||
|
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||||
|
if before_length > after_length:
|
||||||
|
segment['end'] = punctuation_token.start
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||||
|
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||||
|
else:
|
||||||
|
segment['end'] = punctuation_tokens[i - 1].end
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||||
|
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||||
|
break
|
||||||
|
|
||||||
|
last_end = 0.0
|
||||||
|
for token in tokens:
|
||||||
|
start = max(last_end + 0.01, token.start)
|
||||||
|
token.start = start
|
||||||
|
token.end = max(start, token.end)
|
||||||
|
last_end = token.end
|
||||||
|
|
||||||
|
ind_last_speaker = 0
|
||||||
|
for segment in segments_concatenated:
|
||||||
|
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||||
|
if token.end <= segment['end']:
|
||||||
|
token.speaker = segment['speaker']
|
||||||
|
ind_last_speaker = i + 1
|
||||||
|
# print(
|
||||||
|
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||||
|
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||||
|
# )
|
||||||
|
elif token.start > segment['end']:
|
||||||
|
break
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_tokens(tokens):
|
||||||
|
conversation = [{"speaker": -1, "text": ""}]
|
||||||
|
for token in tokens:
|
||||||
|
speaker = conversation[-1]['speaker']
|
||||||
|
if token.speaker != speaker:
|
||||||
|
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||||
|
else:
|
||||||
|
conversation[-1]['text'] += token.text
|
||||||
|
print("Conversation:")
|
||||||
|
for entry in conversation:
|
||||||
|
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||||
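Both diarization backends now expose `insert_silence`, which simply advances a global time offset that is added to every segment observed afterwards, so skipped silence does not desynchronize diarization and transcription timestamps. A toy illustration of the bookkeeping, not the real classes:

```python
# Toy model of the offset bookkeeping behind insert_silence().
global_time_offset = 0.0

def insert_silence(duration: float) -> None:
    global global_time_offset
    global_time_offset += duration

insert_silence(3.2)                         # 3.2 s of audio were skipped
raw_start, raw_end = 10.0, 11.0             # segment times relative to processed audio
stored = (raw_start + global_time_offset,   # 13.2
          raw_end + global_time_offset)     # 14.2
```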
whisperlivekit/diarization/sortformer_backend.py (new file, 466 lines)
@@ -0,0 +1,466 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import wave
|
||||||
|
from typing import List, Optional
|
||||||
|
from queue import SimpleQueue, Empty
|
||||||
|
|
||||||
|
from whisperlivekit.timed_objects import SpeakerSegment
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||||
|
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||||
|
except ImportError:
|
||||||
|
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingSortformerState:
|
||||||
|
"""
|
||||||
|
This class creates a class instance that will be used to store the state of the
|
||||||
|
streaming Sortformer model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||||
|
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||||
|
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||||
|
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||||
|
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||||
|
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||||
|
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||||
|
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||||
|
n_sil_frames (torch.Tensor): Number of silence frames
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.spkcache = None # Speaker cache to store embeddings from start
|
||||||
|
self.spkcache_lengths = None
|
||||||
|
self.spkcache_preds = None # speaker cache predictions
|
||||||
|
self.fifo = None # to save the embedding from the latest chunks
|
||||||
|
self.fifo_lengths = None
|
||||||
|
self.fifo_preds = None
|
||||||
|
self.spk_perm = None
|
||||||
|
self.mean_sil_emb = None
|
||||||
|
self.n_sil_frames = None
|
||||||
|
|
||||||
|
|
||||||
|
class SortformerDiarization:
|
||||||
|
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||||
|
"""
|
||||||
|
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||||
|
"""
|
||||||
|
self._load_model(model_name)
|
||||||
|
|
||||||
|
def _load_model(self, model_name: str):
|
||||||
|
"""Load and configure the Sortformer model for streaming."""
|
||||||
|
try:
|
||||||
|
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||||
|
self.diar_model.eval()
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.diar_model.to(device)
|
||||||
|
|
||||||
|
## to test
|
||||||
|
# for name, param in self.diar_model.named_parameters():
|
||||||
|
# if param.device != device:
|
||||||
|
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||||
|
|
||||||
|
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||||
|
|
||||||
|
self.diar_model.sortformer_modules.chunk_len = 10
|
||||||
|
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||||
|
self.diar_model.sortformer_modules.chunk_right_context = 0
|
||||||
|
self.diar_model.sortformer_modules.chunk_left_context = 10
|
||||||
|
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||||
|
self.diar_model.sortformer_modules.fifo_len = 188
|
||||||
|
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
|
self.diar_model.sortformer_modules.log = False
|
||||||
|
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load Sortformer model: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
class SortformerDiarizationOnline:
|
||||||
|
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||||
|
"""
|
||||||
|
Initialize the streaming Sortformer diarization system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate: Audio sample rate (default: 16000)
|
||||||
|
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||||
|
"""
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.speaker_segments = []
|
||||||
|
self.buffer_audio = np.array([], dtype=np.float32)
|
||||||
|
self.segment_lock = threading.Lock()
|
||||||
|
self.global_time_offset = 0.0
|
||||||
|
self.processed_time = 0.0
|
||||||
|
self.debug = False
|
||||||
|
|
||||||
|
self.diar_model = shared_model.diar_model
|
||||||
|
|
||||||
|
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||||
|
window_size=0.025,
|
||||||
|
normalize="NA",
|
||||||
|
n_fft=512,
|
||||||
|
features=128,
|
||||||
|
pad_to=0
|
||||||
|
)
|
||||||
|
self.audio2mel.to(self.diar_model.device)
|
||||||
|
|
||||||
|
self.chunk_duration_seconds = (
|
||||||
|
self.diar_model.sortformer_modules.chunk_len *
|
||||||
|
self.diar_model.sortformer_modules.subsampling_factor *
|
||||||
|
self.diar_model.preprocessor._cfg.window_stride
|
||||||
|
)
|
||||||
|
|
||||||
|
self._init_streaming_state()
|
||||||
|
|
||||||
|
self._previous_chunk_features = None
|
||||||
|
self._chunk_index = 0
|
||||||
|
self._len_prediction = None
|
||||||
|
|
||||||
|
# Audio buffer to store PCM chunks for debugging
|
||||||
|
self.audio_buffer = []
|
||||||
|
|
||||||
|
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||||
|
self.audio_chunk_buffer = []
|
||||||
|
self.accumulated_duration = 0.0
|
||||||
|
|
||||||
|
logger.info("SortformerDiarization initialized successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def _init_streaming_state(self):
|
||||||
|
"""Initialize the streaming state for the model."""
|
||||||
|
batch_size = 1
|
||||||
|
device = self.diar_model.device
|
||||||
|
|
||||||
|
self.streaming_state = StreamingSortformerState()
|
||||||
|
self.streaming_state.spkcache = torch.zeros(
|
||||||
|
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
self.streaming_state.spkcache_preds = torch.zeros(
|
||||||
|
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
self.streaming_state.fifo = torch.zeros(
|
||||||
|
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||||
|
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# Initialize total predictions tensor
|
||||||
|
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration: float):
|
||||||
|
"""
|
||||||
|
Insert silence period by adjusting the global time offset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
silence_duration: Duration of silence in seconds
|
||||||
|
"""
|
||||||
|
with self.segment_lock:
|
||||||
|
self.global_time_offset += silence_duration
|
||||||
|
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||||
|
|
||||||
|
async def diarize(self, pcm_array: np.ndarray):
|
||||||
|
"""
|
||||||
|
Process audio data for diarization in streaming fashion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pcm_array: Audio data as numpy array
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.debug:
|
||||||
|
self.audio_buffer.append(pcm_array.copy())
|
||||||
|
|
||||||
|
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||||
|
|
||||||
|
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||||
|
if not len(self.buffer_audio) >= threshold:
|
||||||
|
return
|
||||||
|
|
||||||
|
audio = self.buffer_audio[:threshold]
|
||||||
|
self.buffer_audio = self.buffer_audio[threshold:]
|
||||||
|
|
||||||
|
device = self.diar_model.device
|
||||||
|
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||||
|
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||||
|
|
||||||
|
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||||
|
audio_signal_chunk, audio_signal_length_chunk
|
||||||
|
)
|
||||||
|
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||||
|
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||||
|
|
||||||
|
if self._previous_chunk_features is not None:
|
||||||
|
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||||
|
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||||
|
else:
|
||||||
|
total_features = processed_signal_chunk.to(device)
|
||||||
|
|
||||||
|
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||||
|
|
||||||
|
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
left_offset = 8 if self._chunk_index > 0 else 0
|
||||||
|
right_offset = 8
|
||||||
|
|
||||||
|
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||||
|
processed_signal=chunk_feat_seq_t,
|
||||||
|
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||||
|
streaming_state=self.streaming_state,
|
||||||
|
total_preds=self.total_preds,
|
||||||
|
left_offset=left_offset,
|
||||||
|
right_offset=right_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert predictions to speaker segments
|
||||||
|
self._process_predictions()
|
||||||
|
|
||||||
|
self._chunk_index += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in diarize: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# TODO: Handle case when stream ends with partial buffer (accumulated_duration > 0 but < chunk_duration_seconds)
|
||||||
|
|
||||||
|
def _process_predictions(self):
|
||||||
|
"""Process model predictions and convert to speaker segments."""
|
||||||
|
try:
|
||||||
|
preds_np = self.total_preds[0].cpu().numpy()
|
||||||
|
active_speakers = np.argmax(preds_np, axis=1)
|
||||||
|
|
||||||
|
if self._len_prediction is None:
|
||||||
|
self._len_prediction = len(active_speakers)
|
||||||
|
|
||||||
|
# Get predictions for current chunk
|
||||||
|
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||||
|
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||||
|
|
||||||
|
with self.segment_lock:
|
||||||
|
# Process predictions into segments
|
||||||
|
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||||
|
|
||||||
|
for idx, spk in enumerate(current_chunk_preds):
|
||||||
|
start_time = base_time + idx * frame_duration
|
||||||
|
end_time = base_time + (idx + 1) * frame_duration
|
||||||
|
|
||||||
|
# Check if this continues the last segment or starts a new one
|
||||||
|
if (self.speaker_segments and
|
||||||
|
self.speaker_segments[-1].speaker == spk and
|
||||||
|
abs(self.speaker_segments[-1].end - start_time) < frame_duration * 0.5):
|
||||||
|
# Continue existing segment
|
||||||
|
self.speaker_segments[-1].end = end_time
|
||||||
|
else:
|
||||||
|
|
||||||
|
# Create new segment
|
||||||
|
self.speaker_segments.append(SpeakerSegment(
|
||||||
|
speaker=spk,
|
||||||
|
start=start_time,
|
||||||
|
end=end_time
|
||||||
|
))
|
||||||
|
|
||||||
|
# Update processed time
|
||||||
|
self.processed_time = max(self.processed_time, base_time + self.chunk_duration_seconds)
|
||||||
|
|
||||||
|
logger.debug(f"Processed chunk {self._chunk_index}, total segments: {len(self.speaker_segments)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing predictions: {e}")
|
||||||
|
|
||||||
|
def assign_speakers_to_tokens(self, tokens: list, use_punctuation_split: bool = False) -> list:
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: List of tokens with timing information
|
||||||
|
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tokens with speaker assignments
|
||||||
|
Last speaker_segment
|
||||||
|
"""
|
||||||
|
with self.segment_lock:
|
||||||
|
segments = self.speaker_segments.copy()
|
||||||
|
|
||||||
|
if not segments or not tokens:
|
||||||
|
logger.debug("No segments or tokens available for speaker assignment")
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||||
|
use_punctuation_split = False
|
||||||
|
if not use_punctuation_split:
|
||||||
|
# Simple overlap-based assignment
|
||||||
|
for token in tokens:
|
||||||
|
token.speaker = -1 # Default to no speaker
|
||||||
|
for segment in segments:
|
||||||
|
# Check for timing overlap
|
||||||
|
if not (segment.end <= token.start or segment.start >= token.end):
|
||||||
|
token.speaker = segment.speaker + 1 # Convert to 1-based indexing
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Use punctuation-aware assignment (similar to diart_backend)
|
||||||
|
tokens = self._add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def _add_speaker_to_tokens_with_punctuation(self, segments: List[SpeakerSegment], tokens: list) -> list:
|
||||||
|
"""
|
||||||
|
Assign speakers to tokens with punctuation-aware boundary adjustment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segments: List of speaker segments
|
||||||
|
tokens: List of tokens to assign speakers to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tokens with speaker assignments
|
||||||
|
"""
|
||||||
|
punctuation_marks = {'.', '!', '?'}
|
||||||
|
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||||
|
|
||||||
|
# Convert segments to concatenated format
|
||||||
|
segments_concatenated = self._concatenate_speakers(segments)
|
||||||
|
|
||||||
|
# Adjust segment boundaries based on punctuation
|
||||||
|
for ind, segment in enumerate(segments_concatenated):
|
||||||
|
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||||
|
if punctuation_token.start > segment['end']:
|
||||||
|
after_length = punctuation_token.start - segment['end']
|
||||||
|
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||||
|
|
||||||
|
if before_length > after_length:
|
||||||
|
segment['end'] = punctuation_token.start
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||||
|
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||||
|
else:
|
||||||
|
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||||
|
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||||
|
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||||
|
break
|
||||||
|
|
||||||
|
# Ensure non-overlapping tokens
|
||||||
|
last_end = 0.0
|
||||||
|
for token in tokens:
|
||||||
|
start = max(last_end + 0.01, token.start)
|
||||||
|
token.start = start
|
||||||
|
token.end = max(start, token.end)
|
||||||
|
last_end = token.end
|
||||||
|
|
||||||
|
# Assign speakers based on adjusted segments
|
||||||
|
ind_last_speaker = 0
|
||||||
|
for segment in segments_concatenated:
|
||||||
|
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||||
|
if token.end <= segment['end']:
|
||||||
|
token.speaker = segment['speaker']
|
||||||
|
ind_last_speaker = i + 1
|
||||||
|
elif token.start > segment['end']:
|
||||||
|
break
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def _concatenate_speakers(self, segments: List[SpeakerSegment]) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Concatenate consecutive segments from the same speaker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segments: List of speaker segments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of concatenated speaker segments
|
||||||
|
"""
|
||||||
|
if not segments:
|
||||||
|
return []
|
||||||
|
|
||||||
|
segments_concatenated = [{"speaker": segments[0].speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||||
|
|
||||||
|
for segment in segments[1:]:
|
||||||
|
speaker = segment.speaker + 1
|
||||||
|
if segments_concatenated[-1]['speaker'] != speaker:
|
||||||
|
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||||
|
else:
|
||||||
|
segments_concatenated[-1]['end'] = segment.end
|
||||||
|
|
||||||
|
return segments_concatenated
|
||||||
|
|
||||||
|
def get_segments(self) -> List[SpeakerSegment]:
|
||||||
|
"""Get a copy of the current speaker segments."""
|
||||||
|
with self.segment_lock:
|
||||||
|
return self.speaker_segments.copy()
|
||||||
|
|
||||||
|
def clear_old_segments(self, older_than: float = 30.0):
|
||||||
|
"""Clear segments older than the specified time."""
|
||||||
|
with self.segment_lock:
|
||||||
|
current_time = self.processed_time
|
||||||
|
self.speaker_segments = [
|
||||||
|
segment for segment in self.speaker_segments
|
||||||
|
if current_time - segment.end < older_than
|
||||||
|
]
|
||||||
|
logger.debug(f"Cleared old segments, remaining: {len(self.speaker_segments)}")
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the diarization system and clean up resources."""
|
||||||
|
logger.info("Closing SortformerDiarization")
|
||||||
|
with self.segment_lock:
|
||||||
|
self.speaker_segments.clear()
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||||
|
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||||
|
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||||
|
wav_file.setnchannels(1) # mono audio
|
||||||
|
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||||
|
wav_file.setframerate(self.sample_rate)
|
||||||
|
wav_file.writeframes(audio_data_int16.tobytes())
|
||||||
|
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_number(s: str) -> int:
|
||||||
|
"""Extract number from speaker string (compatibility function)."""
|
||||||
|
import re
|
||||||
|
m = re.search(r'\d+', s)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import asyncio
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""TEST ONLY."""
|
||||||
|
an4_audio = 'audio_test.mp3'
|
||||||
|
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||||
|
signal = signal[:16000*30]
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("ground truth:")
|
||||||
|
print("Speaker 0: 0:00 - 0:09")
|
||||||
|
print("Speaker 1: 0:09 - 0:19")
|
||||||
|
print("Speaker 2: 0:19 - 0:25")
|
||||||
|
print("Speaker 0: 0:25 - 0:30")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
diarization = SortformerDiarization(sample_rate=16000)
|
||||||
|
chunk_size = 1600
|
||||||
|
|
||||||
|
for i in range(0, len(signal), chunk_size):
|
||||||
|
chunk = signal[i:i+chunk_size]
|
||||||
|
await diarization.diarize(chunk)
|
||||||
|
print(f"Processed chunk {i // chunk_size + 1}")
|
||||||
|
|
||||||
|
segments = diarization.get_segments()
|
||||||
|
print("\nDiarization results:")
|
||||||
|
for segment in segments:
|
||||||
|
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
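The streaming wrapper above buffers incoming PCM until one model chunk is available. With the values set in `_load_model` (chunk_len = 10, subsampling_factor = 10) and the usual 10 ms mel hop (160 samples at 16 kHz, as the offline prototype's docstring notes), that works out to 1 s of audio per diarization step, the "1 second lag" the offline file also targets. Checking the arithmetic:

```python
chunk_len = 10           # frames per streaming step (set in _load_model)
subsampling_factor = 10  # encoder subsampling factor (set in _load_model)
window_stride = 0.01     # mel hop in seconds; 160 samples at 16 kHz (assumed, per the preprocessor config)

chunk_duration_seconds = chunk_len * subsampling_factor * window_stride  # 1.0 s
threshold = int(chunk_duration_seconds * 16000)                          # 16000 samples buffered per step
print(chunk_duration_seconds, threshold)
```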
whisperlivekit/diarization/sortformer_backend_offline.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||||
|
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
|
||||||
|
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2")
|
||||||
|
diar_model.eval()
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
diar_model.to(torch.device("cuda"))
|
||||||
|
|
||||||
|
#we target 1 second lag for the moment. chunk_len could be reduced.
|
||||||
|
diar_model.sortformer_modules.chunk_len = 10
|
||||||
|
diar_model.sortformer_modules.subsampling_factor = 10 #8 would be better ideally
|
||||||
|
|
||||||
|
diar_model.sortformer_modules.chunk_right_context = 0 #no.
|
||||||
|
diar_model.sortformer_modules.chunk_left_context = 10 #big so it compensiate the problem with no padding later.
|
||||||
|
|
||||||
|
diar_model.sortformer_modules.spkcache_len = 188
|
||||||
|
diar_model.sortformer_modules.fifo_len = 188
|
||||||
|
diar_model.sortformer_modules.spkcache_update_period = 144
|
||||||
|
diar_model.sortformer_modules.log = False
|
||||||
|
diar_model.sortformer_modules._check_streaming_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||||
|
window_size= 0.025,
|
||||||
|
normalize="NA",
|
||||||
|
n_fft=512,
|
||||||
|
features=128,
|
||||||
|
pad_to=0) #pad_to 16 works better than 0. On test audio, we detect a third speaker for 1 second with pad_to=0. To solve that : increase left context to 10.
|
||||||
|
|
||||||
|
return diar_model, audio2mel
|
||||||
|
|
||||||
|
diar_model, audio2mel = load_model()
|
||||||
|
|
||||||
|
class StreamingSortformerState:
|
||||||
|
"""
|
||||||
|
This class creates a class instance that will be used to store the state of the
|
||||||
|
streaming Sortformer model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||||
|
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||||
|
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||||
|
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||||
|
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||||
|
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||||
|
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||||
|
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||||
|
n_sil_frames (torch.Tensor): Number of silence frames
|
||||||
|
"""
|
||||||
|
|
||||||
|
spkcache = None # Speaker cache to store embeddings from start
|
||||||
|
spkcache_lengths = None #
|
||||||
|
spkcache_preds = None # speaker cache predictions
|
||||||
|
fifo = None # to save the embedding from the latest chunks
|
||||||
|
fifo_lengths = None
|
||||||
|
fifo_preds = None
|
||||||
|
spk_perm = None
|
||||||
|
mean_sil_emb = None
|
||||||
|
n_sil_frames = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_streaming_state(self, batch_size: int = 1, async_streaming: bool = False, device: torch.device = None):
|
||||||
|
"""
|
||||||
|
Initializes StreamingSortformerState with empty tensors or zero-valued tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size (int): Batch size for tensors in streaming state
|
||||||
|
async_streaming (bool): True for asynchronous update, False for synchronous update
|
||||||
|
device (torch.device): Device for tensors in streaming state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
streaming_state (SortformerStreamingState): initialized streaming state
|
||||||
|
"""
|
||||||
|
streaming_state = StreamingSortformerState()
|
||||||
|
if async_streaming:
|
||||||
|
streaming_state.spkcache = torch.zeros((batch_size, self.spkcache_len, self.fc_d_model), device=device)
|
||||||
|
streaming_state.spkcache_preds = torch.zeros((batch_size, self.spkcache_len, self.n_spk), device=device)
|
||||||
|
streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
streaming_state.fifo = torch.zeros((batch_size, self.fifo_len, self.fc_d_model), device=device)
|
||||||
|
streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
else:
|
||||||
|
streaming_state.spkcache = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||||
|
streaming_state.fifo = torch.zeros((batch_size, 0, self.fc_d_model), device=device)
|
||||||
|
streaming_state.mean_sil_emb = torch.zeros((batch_size, self.fc_d_model), device=device)
|
||||||
|
streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||||
|
return streaming_state
|
||||||
|
|
||||||
|
|
||||||
|
def process_diarization(chunks):
|
||||||
|
"""
|
||||||
|
what it does:
|
||||||
|
1. Preprocessing: Applies dithering and pre-emphasis (high-pass filter) if enabled
|
||||||
|
2. STFT: Computes the Short-Time Fourier Transform using:
|
||||||
|
- the window of window_size=0.025 --> size of a window : 400 samples
|
||||||
|
- the hop parameter : n_window_stride = 0.01 -> every 160 samples, a new window
|
||||||
|
3. Magnitude Calculation: Converts complex STFT output to magnitude spectrogram
|
||||||
|
4. Mel Conversion: Applies Mel filterbanks (128 filters in this case) to get Mel spectrogram
|
||||||
|
5. Logarithm: Takes the log of the Mel spectrogram (if `log=True`)
|
||||||
|
6. Normalization: Skips normalization since `normalize="NA"`
|
||||||
|
7. Padding: Pads the time dimension to a multiple of `pad_to` (default 16)
|
||||||
|
"""
|
||||||
|
previous_chunk = None
|
||||||
|
l_chunk_feat_seq_t = []
|
||||||
|
for chunk in chunks:
|
||||||
|
audio_signal_chunk = torch.tensor(chunk).unsqueeze(0).to(diar_model.device)
|
||||||
|
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]]).to(diar_model.device)
|
||||||
|
processed_signal_chunk, processed_signal_length_chunk = audio2mel.get_features(audio_signal_chunk, audio_signal_length_chunk)
|
||||||
|
if previous_chunk is not None:
|
||||||
|
to_add = previous_chunk[:, :, -99:]
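# The last 99 mel frames (~0.99 s at a 10 ms hop) of the previous chunk are presumably
# kept as left context for the current chunk, matching the left_offset used further below.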
|
||||||
|
total = torch.concat([to_add, processed_signal_chunk], dim=2)
|
||||||
|
else:
|
||||||
|
total = processed_signal_chunk
|
||||||
|
previous_chunk = processed_signal_chunk
|
||||||
|
l_chunk_feat_seq_t.append(torch.transpose(total, 1, 2))
|
||||||
|
|
||||||
|
batch_size = 1
|
||||||
|
streaming_state = init_streaming_state(diar_model.sortformer_modules,
|
||||||
|
batch_size = batch_size,
|
||||||
|
async_streaming = True,
|
||||||
|
device = diar_model.device
|
||||||
|
)
|
||||||
|
total_preds = torch.zeros((batch_size, 0, diar_model.sortformer_modules.n_spk), device=diar_model.device)
|
||||||
|
|
||||||
|
chunk_duration_seconds = diar_model.sortformer_modules.chunk_len * diar_model.sortformer_modules.subsampling_factor * diar_model.preprocessor._cfg.window_stride
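# Illustrative only (hypothetical values, not read from the loaded model): with
# chunk_len = 6, subsampling_factor = 8 and window_stride = 0.01 s, each streaming
# chunk would cover 6 * 8 * 0.01 = 0.48 s of audio.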
|
||||||
|
|
||||||
|
l_speakers = [
|
||||||
|
{'start_time': 0,
|
||||||
|
'end_time': 0,
|
||||||
|
'speaker': 0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
len_prediction = None
|
||||||
|
left_offset = 0
|
||||||
|
right_offset = 8
|
||||||
|
for i, chunk_feat_seq_t in enumerate(l_chunk_feat_seq_t):
|
||||||
|
with torch.inference_mode():
|
||||||
|
streaming_state, total_preds = diar_model.forward_streaming_step(
|
||||||
|
processed_signal=chunk_feat_seq_t,
|
||||||
|
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]),
|
||||||
|
streaming_state=streaming_state,
|
||||||
|
total_preds=total_preds,
|
||||||
|
left_offset=left_offset,
|
||||||
|
right_offset=right_offset,
|
||||||
|
)
|
||||||
|
left_offset = 8
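# After the first chunk, left_offset = 8 presumably marks the leading encoder frames as
# overlap carried over from the previous chunk (the 99-frame mel overlap above), so they
# are not re-emitted as new predictions.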
|
||||||
|
preds_np = total_preds[0].cpu().numpy()
|
||||||
|
active_speakers = np.argmax(preds_np, axis=1)
|
||||||
|
if len_prediction is None:
|
||||||
|
len_prediction = len(active_speakers)  # number of prediction frames produced by a single chunk
|
||||||
|
frame_duration = chunk_duration_seconds / len_prediction
|
||||||
|
active_speakers = active_speakers[-len_prediction:]
|
||||||
|
for idx, spk in enumerate(active_speakers):
|
||||||
|
if spk != l_speakers[-1]['speaker']:
|
||||||
|
l_speakers.append(
|
||||||
|
{'start_time': (i * chunk_duration_seconds + idx * frame_duration),
|
||||||
|
'end_time': (i * chunk_duration_seconds + (idx + 1) * frame_duration),
|
||||||
|
'speaker': spk
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
l_speakers[-1]['end_time'] = i * chunk_duration_seconds + (idx + 1) * frame_duration
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Should print
|
||||||
|
[{'start_time': 0, 'end_time': 8.72, 'speaker': 0},
|
||||||
|
{'start_time': 8.72, 'end_time': 18.88, 'speaker': 1},
|
||||||
|
{'start_time': 18.88, 'end_time': 24.96, 'speaker': 2},
|
||||||
|
{'start_time': 24.96, 'end_time': 31.68, 'speaker': 0}]
|
||||||
|
"""
|
||||||
|
for speaker in l_speakers:
|
||||||
|
print(f"Speaker {speaker['speaker']}: {speaker['start_time']:.2f}s - {speaker['end_time']:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
an4_audio = 'audio_test.mp3'
|
||||||
|
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||||
|
signal = signal[:16000*30]
|
||||||
|
# signal = signal[:-(len(signal)%16000)]
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Expected ground truth:")
|
||||||
|
print("Speaker 0: 0:00 - 0:09")
|
||||||
|
print("Speaker 1: 0:09 - 0:19")
|
||||||
|
print("Speaker 2: 0:19 - 0:25")
|
||||||
|
print("Speaker 0: 0:25 - 0:30")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
chunk_size = 16000 # 1 second
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(signal), chunk_size):
|
||||||
|
chunk = signal[i:i+chunk_size]
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
process_diarization(chunks)
|
||||||
197
whisperlivekit/ffmpeg_manager.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Callable
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||||
|
{'='*50}
|
||||||
|
FFmpeg is not installed or not found in your system's PATH.
|
||||||
|
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||||
|
|
||||||
|
If you want to install FFmpeg:
|
||||||
|
|
||||||
|
# Ubuntu/Debian:
|
||||||
|
sudo apt update && sudo apt install ffmpeg
|
||||||
|
|
||||||
|
# macOS (using Homebrew):
|
||||||
|
brew install ffmpeg
|
||||||
|
|
||||||
|
# Windows:
|
||||||
|
# 1. Download the latest static build from https://ffmpeg.org/download.html
|
||||||
|
# 2. Extract the archive (e.g., to C:\\FFmpeg).
|
||||||
|
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||||
|
|
||||||
|
After installation, please restart the application.
|
||||||
|
{'='*50}
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FFmpegState(Enum):
|
||||||
|
STOPPED = "stopped"
|
||||||
|
STARTING = "starting"
|
||||||
|
RUNNING = "running"
|
||||||
|
RESTARTING = "restarting"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
class FFmpegManager:
|
||||||
|
def __init__(self, sample_rate: int = 16000, channels: int = 1):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.channels = channels
|
||||||
|
|
||||||
|
self.process: Optional[asyncio.subprocess.Process] = None
|
||||||
|
self._stderr_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
self.on_error_callback: Optional[Callable[[str], None]] = None
|
||||||
|
|
||||||
|
self.state = FFmpegState.STOPPED
|
||||||
|
self._state_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(self) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.STOPPED:
|
||||||
|
logger.warning(f"FFmpeg already running in state: {self.state}")
|
||||||
|
return False
|
||||||
|
self.state = FFmpegState.STARTING
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-hide_banner",
|
||||||
|
"-loglevel", "error",
|
||||||
|
"-i", "pipe:0",
|
||||||
|
"-f", "s16le",
|
||||||
|
"-acodec", "pcm_s16le",
|
||||||
|
"-ac", str(self.channels),
|
||||||
|
"-ar", str(self.sample_rate),
|
||||||
|
"pipe:1"
|
||||||
|
]
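# The command above reads any container/codec FFmpeg can decode from stdin ("pipe:0")
# and writes raw signed 16-bit little-endian PCM with self.channels channels at
# self.sample_rate to stdout ("pipe:1").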
|
||||||
|
|
||||||
|
self.process = await asyncio.create_subprocess_exec(
|
||||||
|
*cmd,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE
|
||||||
|
)
|
||||||
|
|
||||||
|
self._stderr_task = asyncio.create_task(self._drain_stderr())
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.RUNNING
|
||||||
|
|
||||||
|
logger.info("FFmpeg started.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(ERROR_INSTALL_INSTRUCTIONS)
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("ffmpeg_not_found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting FFmpeg: {e}")
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("start_failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state == FFmpegState.STOPPED:
|
||||||
|
return
|
||||||
|
self.state = FFmpegState.STOPPED
|
||||||
|
|
||||||
|
if self.process:
|
||||||
|
if self.process.stdin and not self.process.stdin.is_closing():
|
||||||
|
self.process.stdin.close()
|
||||||
|
await self.process.stdin.wait_closed()
|
||||||
|
await self.process.wait()
|
||||||
|
self.process = None
|
||||||
|
|
||||||
|
if self._stderr_task:
|
||||||
|
self._stderr_task.cancel()
|
||||||
|
with contextlib.suppress(asyncio.CancelledError):
|
||||||
|
await self._stderr_task
|
||||||
|
|
||||||
|
logger.info("FFmpeg stopped.")
|
||||||
|
|
||||||
|
async def write_data(self, data: bytes) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.RUNNING:
|
||||||
|
logger.warning(f"Cannot write, FFmpeg state: {self.state}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.process.stdin.write(data)
|
||||||
|
await self.process.stdin.drain()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing to FFmpeg: {e}")
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("write_error")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def read_data(self, size: int) -> Optional[bytes]:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state != FFmpegState.RUNNING:
|
||||||
|
logger.warning(f"Cannot read, FFmpeg state: {self.state}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await asyncio.wait_for(
|
||||||
|
self.process.stdout.read(size),
|
||||||
|
timeout=20.0
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("FFmpeg read timeout.")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading from FFmpeg: {e}")
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("read_error")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_state(self) -> FFmpegState:
|
||||||
|
async with self._state_lock:
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
async def restart(self) -> bool:
|
||||||
|
async with self._state_lock:
|
||||||
|
if self.state == FFmpegState.RESTARTING:
|
||||||
|
logger.warning("Restart already in progress.")
|
||||||
|
return False
|
||||||
|
self.state = FFmpegState.RESTARTING
|
||||||
|
|
||||||
|
logger.info("Restarting FFmpeg...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.stop()
|
||||||
|
await asyncio.sleep(1) # short delay before restarting
|
||||||
|
return await self.start()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during FFmpeg restart: {e}")
|
||||||
|
async with self._state_lock:
|
||||||
|
self.state = FFmpegState.FAILED
|
||||||
|
if self.on_error_callback:
|
||||||
|
await self.on_error_callback("restart_failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _drain_stderr(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if not self.process or not self.process.stderr:
|
||||||
|
break
|
||||||
|
line = await self.process.stderr.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
logger.debug(f"FFmpeg stderr: {line.decode(errors='ignore').strip()}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("FFmpeg stderr drain task cancelled.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||||
@@ -20,7 +20,7 @@ def parse_args():
|
|||||||
help="""
|
help="""
|
||||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||||
If False, no warmup is performed.
|
If empty, no warmup is performed.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,12 +58,26 @@ def parse_args():
|
|||||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--diarization-backend",
|
||||||
|
type=str,
|
||||||
|
default="sortformer",
|
||||||
|
choices=["sortformer", "diart"],
|
||||||
|
help="The diarization backend to use.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-transcription",
|
"--no-transcription",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable transcription to only see live diarization results.",
|
help="Disable transcription to only see live diarization results.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-punctuation-split",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable the split parameter.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-chunk-size",
|
"--min-chunk-size",
|
||||||
type=float,
|
type=float,
|
||||||
@@ -74,7 +88,8 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
default="tiny",
|
default="small",
|
||||||
|
dest='model_size',
|
||||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,6 +110,7 @@ def parse_args():
|
|||||||
"--language",
|
"--language",
|
||||||
type=str,
|
type=str,
|
||||||
default="auto",
|
default="auto",
|
||||||
|
dest='lan',
|
||||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -104,18 +120,27 @@ def parse_args():
|
|||||||
choices=["transcribe", "translate"],
|
choices=["transcribe", "translate"],
|
||||||
help="Transcribe or translate.",
|
help="Transcribe or translate.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-language",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
dest="target_language",
|
||||||
|
help="Target language for translation. Not functional yet.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
type=str,
|
type=str,
|
||||||
default="faster-whisper",
|
default="simulstreaming",
|
||||||
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api", "simulstreaming"],
|
||||||
help="Load only this backend for Whisper processing.",
|
help="Load only this backend for Whisper processing.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vac",
|
"--no-vac",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
help="Disable VAC = voice activity controller.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||||
@@ -150,9 +175,30 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||||
|
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pcm-input",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||||
|
)
|
||||||
# SimulStreaming-specific arguments
|
# SimulStreaming-specific arguments
|
||||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend simulstreaming)')
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--disable-fast-encoder",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
dest="disable_fast_encoder",
|
||||||
|
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||||
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--custom-alignment-heads",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||||
|
)
|
||||||
|
|
||||||
simulstreaming_group.add_argument(
|
simulstreaming_group.add_argument(
|
||||||
"--frame-threshold",
|
"--frame-threshold",
|
||||||
@@ -242,6 +288,28 @@ def parse_args():
|
|||||||
dest="model_path",
|
dest="model_path",
|
||||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--preload-model-count",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
dest="preload_model_count",
|
||||||
|
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||||
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--nllb-backend",
|
||||||
|
type=str,
|
||||||
|
default="transformers",
|
||||||
|
help="transformers or ctranslate2",
|
||||||
|
)
|
||||||
|
|
||||||
|
simulstreaming_group.add_argument(
|
||||||
|
"--nllb-size",
|
||||||
|
type=str,
|
||||||
|
default="600M",
|
||||||
|
help="600M or 1.3B",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
106
whisperlivekit/remove_silences.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from whisperlivekit.timed_objects import ASRToken
|
||||||
|
from time import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
MIN_SILENCE_DURATION = 4 #in seconds
|
||||||
|
END_SILENCE_DURATION = 8  # in seconds. Keep this high to avoid false positives when the model lags significantly
|
||||||
|
END_SILENCE_DURATION_VAC = 3  # in seconds. VAC detects silences well, but we want to skip the shortest ones
|
||||||
|
|
||||||
|
def blank_to_silence(tokens):
|
||||||
|
full_string = ''.join([t.text for t in tokens])
|
||||||
|
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
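# Illustrative example of what these patterns match (assumed token texts):
#   ''.join(t.text for t in tokens) == "so yeah [BLANK_AUDIO] [BLANK_AUDIO] anyway"
#   -> one match covering " [BLANK_AUDIO] [BLANK_AUDIO] ", whose character span is later
#      mapped back onto the tokens and collapsed into a speaker == -2 silence token.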
|
||||||
|
matches = []
|
||||||
|
for pattern in patterns:
|
||||||
|
for m in pattern.finditer(full_string):
|
||||||
|
matches.append({
|
||||||
|
'start': m.start(),
|
||||||
|
'end': m.end()
|
||||||
|
})
|
||||||
|
if matches:
|
||||||
|
# cleaned = pattern.sub(' ', full_string).strip()
|
||||||
|
# print("Cleaned:", cleaned)
|
||||||
|
cumulated_len = 0
|
||||||
|
silence_token = None
|
||||||
|
cleaned_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
if matches:
|
||||||
|
start = cumulated_len
|
||||||
|
end = cumulated_len + len(token.text)
|
||||||
|
cumulated_len = end
|
||||||
|
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||||
|
if silence_token: #previous token was already silence
|
||||||
|
silence_token.start = min(silence_token.start, token.start)
|
||||||
|
silence_token.end = max(silence_token.end, token.end)
|
||||||
|
else: #new silence
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=token.start,
|
||||||
|
end=token.end,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if silence_token: #there was silence but no more
|
||||||
|
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||||
|
cleaned_tokens.append(
|
||||||
|
silence_token
|
||||||
|
)
|
||||||
|
silence_token = None
|
||||||
|
matches.pop(0)
|
||||||
|
cleaned_tokens.append(token)
|
||||||
|
# print(cleaned_tokens)
|
||||||
|
return cleaned_tokens
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def no_token_to_silence(tokens):
|
||||||
|
new_tokens = []
|
||||||
|
silence_token = None
|
||||||
|
for token in tokens:
|
||||||
|
if token.speaker == -2:
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||||
|
new_tokens[-1].end = token.end
|
||||||
|
else:
|
||||||
|
new_tokens.append(token)
|
||||||
|
|
||||||
|
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||||
|
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but important gap
|
||||||
|
if new_tokens and new_tokens[-1].speaker == -2:
|
||||||
|
new_tokens[-1].end = token.start
|
||||||
|
else:
|
||||||
|
silence_token = ASRToken(
|
||||||
|
start=last_end,
|
||||||
|
end=token.start,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
new_tokens.append(silence_token)
|
||||||
|
|
||||||
|
if token.speaker != -2:
|
||||||
|
new_tokens.append(token)
|
||||||
|
return new_tokens
|
||||||
|
|
||||||
|
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||||
|
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||||
|
last_token = tokens[-1]
|
||||||
|
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||||
|
if last_token.speaker == -2:
|
||||||
|
last_token.end = current_time
|
||||||
|
else:
|
||||||
|
tokens.append(
|
||||||
|
ASRToken(
|
||||||
|
start=tokens[-1].end,
|
||||||
|
end=current_time,
|
||||||
|
speaker=-2,
|
||||||
|
probability=0.95
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||||
|
if not tokens:
|
||||||
|
return []
|
||||||
|
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||||
|
tokens = no_token_to_silence(tokens)
|
||||||
|
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||||
|
return tokens
|
||||||
|
|
||||||
154
whisperlivekit/results_formater.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
|
||||||
|
import logging
|
||||||
|
from whisperlivekit.remove_silences import handle_silences
|
||||||
|
from whisperlivekit.timed_objects import Line, format_time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
CHECK_AROUND = 4
|
||||||
|
DEBUG = False
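# CHECK_AROUND = how many neighbouring tokens are scanned when reconciling a speaker
# change with the nearest punctuation (see next_punctuation_change / next_speaker_change below).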
|
||||||
|
|
||||||
|
|
||||||
|
def is_punctuation(token):
|
||||||
|
if token.is_punctuation():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def next_punctuation_change(i, tokens):
|
||||||
|
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||||
|
if is_punctuation(tokens[ind]):
|
||||||
|
return ind
|
||||||
|
return None
|
||||||
|
|
||||||
|
def next_speaker_change(i, tokens, speaker):
|
||||||
|
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||||
|
token = tokens[ind]
|
||||||
|
if is_punctuation(token):
|
||||||
|
break
|
||||||
|
if token.speaker != speaker:
|
||||||
|
return ind, token.speaker
|
||||||
|
return None, speaker
|
||||||
|
|
||||||
|
def new_line(
|
||||||
|
token,
|
||||||
|
):
|
||||||
|
return Line(
|
||||||
|
speaker = token.corrected_speaker,
|
||||||
|
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
|
||||||
|
start = token.start,
|
||||||
|
end = token.end,
|
||||||
|
detected_language=token.detected_language
|
||||||
|
)
|
||||||
|
|
||||||
|
def append_token_to_last_line(lines, sep, token):
|
||||||
|
if not lines:
|
||||||
|
lines.append(new_line(token))
|
||||||
|
else:
|
||||||
|
if token.text:
|
||||||
|
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
|
||||||
|
lines[-1].end = token.end
|
||||||
|
if not lines[-1].detected_language and token.detected_language:
|
||||||
|
lines[-1].detected_language = token.detected_language
|
||||||
|
|
||||||
|
|
||||||
|
def format_output(state, silence, args, sep):
|
||||||
|
diarization = args.diarization
|
||||||
|
disable_punctuation_split = args.disable_punctuation_split
|
||||||
|
tokens = state.tokens
|
||||||
|
translation_validated_segments = state.translation_validated_segments # Here we will attribute the speakers only based on the timestamps of the segments
|
||||||
|
translation_buffer = state.translation_buffer
|
||||||
|
last_validated_token = state.last_validated_token
|
||||||
|
|
||||||
|
previous_speaker = 1
|
||||||
|
undiarized_text = []
|
||||||
|
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||||
|
last_punctuation = None
|
||||||
|
for i, token in enumerate(tokens[last_validated_token:]):
|
||||||
|
speaker = int(token.speaker)
|
||||||
|
token.corrected_speaker = speaker
|
||||||
|
if not diarization:
|
||||||
|
if speaker == -1: #Speaker -1 means no attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||||
|
token.corrected_speaker = 1
|
||||||
|
token.validated_speaker = True
|
||||||
|
else:
|
||||||
|
if is_punctuation(token):
|
||||||
|
last_punctuation = i
|
||||||
|
|
||||||
|
if last_punctuation == i-1:
|
||||||
|
if token.speaker != previous_speaker:
|
||||||
|
token.validated_speaker = True
|
||||||
|
# perfect, diarization perfectly aligned
|
||||||
|
last_punctuation = None
|
||||||
|
else:
|
||||||
|
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||||
|
if speaker_change_pos:
|
||||||
|
# Corrects delay:
|
||||||
|
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||||
|
# should become:
|
||||||
|
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||||
|
token.corrected_speaker = new_speaker
|
||||||
|
token.validated_speaker = True
|
||||||
|
elif speaker != previous_speaker:
|
||||||
|
if not (speaker == -2 or previous_speaker == -2):
|
||||||
|
if next_punctuation_change(i, tokens):
|
||||||
|
# Corrects advance:
|
||||||
|
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||||
|
# should become:
|
||||||
|
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||||
|
token.corrected_speaker = previous_speaker
|
||||||
|
token.validated_speaker = True
|
||||||
|
else: # Problematic unless the language has no punctuation: keep the previous speaker (append to the previous line), unless disable_punctuation_split is set to True.
|
||||||
|
if not disable_punctuation_split:
|
||||||
|
token.corrected_speaker = previous_speaker
|
||||||
|
token.validated_speaker = False
|
||||||
|
if token.validated_speaker:
|
||||||
|
state.last_validated_token = i
|
||||||
|
previous_speaker = token.corrected_speaker
|
||||||
|
|
||||||
|
previous_speaker = 1
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
for token in tokens:
|
||||||
|
if int(token.corrected_speaker) != int(previous_speaker):
|
||||||
|
lines.append(new_line(token))
|
||||||
|
else:
|
||||||
|
append_token_to_last_line(lines, sep, token)
|
||||||
|
|
||||||
|
previous_speaker = token.corrected_speaker
|
||||||
|
|
||||||
|
if lines:
|
||||||
|
unassigned_translated_segments = []
|
||||||
|
for ts in translation_validated_segments:
|
||||||
|
assigned = False
|
||||||
|
for line in lines:
|
||||||
|
if ts and ts.overlaps_with(line):
|
||||||
|
if ts.is_within(line):
|
||||||
|
line.translation += ts.text + ' '
|
||||||
|
assigned = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||||
|
if ts0 and line.overlaps_with(ts0):
|
||||||
|
line.translation += ts0.text + ' '
|
||||||
|
if ts1:
|
||||||
|
unassigned_translated_segments.append(ts1)
|
||||||
|
assigned = True
|
||||||
|
break
|
||||||
|
if not assigned:
|
||||||
|
unassigned_translated_segments.append(ts)
|
||||||
|
|
||||||
|
if unassigned_translated_segments:
|
||||||
|
for line in lines:
|
||||||
|
remaining_segments = []
|
||||||
|
for ts in unassigned_translated_segments:
|
||||||
|
if ts and ts.overlaps_with(line):
|
||||||
|
line.translation += ts.text + ' '
|
||||||
|
else:
|
||||||
|
remaining_segments.append(ts)
|
||||||
|
unassigned_translated_segments = remaining_segments  # TODO: handle segments that remain unassigned
|
||||||
|
|
||||||
|
if state.buffer_transcription and lines:
|
||||||
|
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||||
|
|
||||||
|
return lines, undiarized_text
|
||||||
294
whisperlivekit/silero_vad_iterator.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
"""
|
||||||
|
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||||
|
"""
|
||||||
|
|
||||||
|
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||||
|
"""Load a JIT model from file."""
|
||||||
|
model = torch.jit.load(model_path, map_location=device)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxWrapper():
|
||||||
|
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||||
|
|
||||||
|
def __init__(self, path, force_onnx_cpu=False):
|
||||||
|
global np
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime
|
||||||
|
|
||||||
|
opts = onnxruntime.SessionOptions()
|
||||||
|
opts.inter_op_num_threads = 1
|
||||||
|
opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||||
|
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||||
|
else:
|
||||||
|
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||||
|
|
||||||
|
self.reset_states()
|
||||||
|
if '16k' in path:
|
||||||
|
warnings.warn('This model supports only a 16000 Hz sampling rate!')
|
||||||
|
self.sample_rates = [16000]
|
||||||
|
else:
|
||||||
|
self.sample_rates = [8000, 16000]
|
||||||
|
|
||||||
|
def _validate_input(self, x, sr: int):
|
||||||
|
if x.dim() == 1:
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
if x.dim() > 2:
|
||||||
|
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||||
|
|
||||||
|
if sr != 16000 and (sr % 16000 == 0):
|
||||||
|
step = sr // 16000
|
||||||
|
x = x[:,::step]
|
||||||
|
sr = 16000
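# Note: inputs at multiples of 16 kHz (e.g. 32 or 48 kHz) are reduced by simple decimation
# (taking every step-th sample) rather than by a proper anti-aliased resample.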
|
||||||
|
|
||||||
|
if sr not in self.sample_rates:
|
||||||
|
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||||
|
if sr / x.shape[1] > 31.25:
|
||||||
|
raise ValueError("Input audio chunk is too short")
|
||||||
|
|
||||||
|
return x, sr
|
||||||
|
|
||||||
|
def reset_states(self, batch_size=1):
|
||||||
|
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||||
|
self._context = torch.zeros(0)
|
||||||
|
self._last_sr = 0
|
||||||
|
self._last_batch_size = 0
|
||||||
|
|
||||||
|
def __call__(self, x, sr: int):
|
||||||
|
|
||||||
|
x, sr = self._validate_input(x, sr)
|
||||||
|
num_samples = 512 if sr == 16000 else 256
|
||||||
|
|
||||||
|
if x.shape[-1] != num_samples:
|
||||||
|
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||||
|
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
context_size = 64 if sr == 16000 else 32
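# Each 512-sample window is prepended with the last 64 samples of the previous window
# (32 at 8 kHz), so the model always sees some left context across window boundaries.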
|
||||||
|
|
||||||
|
if not self._last_batch_size:
|
||||||
|
self.reset_states(batch_size)
|
||||||
|
if (self._last_sr) and (self._last_sr != sr):
|
||||||
|
self.reset_states(batch_size)
|
||||||
|
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||||
|
self.reset_states(batch_size)
|
||||||
|
|
||||||
|
if not len(self._context):
|
||||||
|
self._context = torch.zeros(batch_size, context_size)
|
||||||
|
|
||||||
|
x = torch.cat([self._context, x], dim=1)
|
||||||
|
if sr in [8000, 16000]:
|
||||||
|
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||||
|
ort_outs = self.session.run(None, ort_inputs)
|
||||||
|
out, state = ort_outs
|
||||||
|
self._state = torch.from_numpy(state)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
self._context = x[..., -context_size:]
|
||||||
|
self._last_sr = sr
|
||||||
|
self._last_batch_size = batch_size
|
||||||
|
|
||||||
|
out = torch.from_numpy(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||||
|
"""
|
||||||
|
Load Silero VAD model (JIT or ONNX).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_path : str, optional
|
||||||
|
Path to model file. If None, uses default bundled model.
|
||||||
|
onnx : bool, default False
|
||||||
|
Whether to use ONNX runtime (requires onnxruntime package).
|
||||||
|
opset_version : int, default 16
|
||||||
|
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
model
|
||||||
|
Loaded VAD model (JIT or ONNX wrapper)
|
||||||
|
"""
|
||||||
|
available_ops = [15, 16]
|
||||||
|
if onnx and opset_version not in available_ops:
|
||||||
|
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||||
|
if model_path is None:
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
data_dir = current_dir / 'vad_models'
|
||||||
|
|
||||||
|
if onnx:
|
||||||
|
if opset_version == 16:
|
||||||
|
model_name = 'silero_vad.onnx'
|
||||||
|
else:
|
||||||
|
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||||
|
else:
|
||||||
|
model_name = 'silero_vad.jit'
|
||||||
|
|
||||||
|
model_path = data_dir / model_name
|
||||||
|
|
||||||
|
if not model_path.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Model file not found: {model_path}\n"
|
||||||
|
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_path = Path(model_path)
|
||||||
|
if onnx:
|
||||||
|
try:
|
||||||
|
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||||
|
"Or use JIT model by setting onnx=False"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = init_jit_model(str(model_path))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class VADIterator:
|
||||||
|
"""
|
||||||
|
Voice Activity Detection iterator for streaming audio.
|
||||||
|
|
||||||
|
This is the Silero VAD v6 implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
sampling_rate: int = 16000,
|
||||||
|
min_silence_duration_ms: int = 100,
|
||||||
|
speech_pad_ms: int = 30
|
||||||
|
):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Class for stream imitation
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: preloaded .jit/.onnx silero VAD model
|
||||||
|
|
||||||
|
threshold: float (default - 0.5)
|
||||||
|
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||||
|
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||||
|
|
||||||
|
sampling_rate: int (default - 16000)
|
||||||
|
Currently silero VAD models support 8000 and 16000 sample rates
|
||||||
|
|
||||||
|
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||||
|
In the end of each speech chunk wait for min_silence_duration_ms before separating it
|
||||||
|
|
||||||
|
speech_pad_ms: int (default - 30 milliseconds)
|
||||||
|
Final speech chunks are padded by speech_pad_ms each side
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.threshold = threshold
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
|
if sampling_rate not in [8000, 16000]:
|
||||||
|
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||||
|
|
||||||
|
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||||
|
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||||
|
self.reset_states()
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
|
||||||
|
self.model.reset_states()
|
||||||
|
self.triggered = False
|
||||||
|
self.temp_end = 0
|
||||||
|
self.current_sample = 0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||||
|
"""
|
||||||
|
x: torch.Tensor
|
||||||
|
audio chunk (see examples in repo)
|
||||||
|
|
||||||
|
return_seconds: bool (default - False)
|
||||||
|
whether return timestamps in seconds (default - samples)
|
||||||
|
|
||||||
|
time_resolution: int (default - 1)
|
||||||
|
time resolution of speech coordinates when requested as seconds
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not torch.is_tensor(x):
|
||||||
|
try:
|
||||||
|
x = torch.Tensor(x)
|
||||||
|
except:
|
||||||
|
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
||||||
|
|
||||||
|
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||||
|
self.current_sample += window_size_samples
|
||||||
|
|
||||||
|
speech_prob = self.model(x, self.sampling_rate).item()
|
||||||
|
|
||||||
|
if (speech_prob >= self.threshold) and self.temp_end:
|
||||||
|
self.temp_end = 0
|
||||||
|
|
||||||
|
if (speech_prob >= self.threshold) and not self.triggered:
|
||||||
|
self.triggered = True
|
||||||
|
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||||
|
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||||
|
|
||||||
|
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||||
|
if not self.temp_end:
|
||||||
|
self.temp_end = self.current_sample
|
||||||
|
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||||
|
self.temp_end = 0
|
||||||
|
self.triggered = False
|
||||||
|
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class FixedVADIterator(VADIterator):
|
||||||
|
"""
|
||||||
|
Fixed VAD Iterator that handles variable-length audio chunks, not only exactly 512 frames at once.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reset_states(self):
|
||||||
|
super().reset_states()
|
||||||
|
self.buffer = np.array([], dtype=np.float32)
|
||||||
|
|
||||||
|
def __call__(self, x, return_seconds=False):
|
||||||
|
self.buffer = np.append(self.buffer, x)
|
||||||
|
ret = None
|
||||||
|
while len(self.buffer) >= 512:
|
||||||
|
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
|
||||||
|
self.buffer = self.buffer[512:]
|
||||||
|
if ret is None:
|
||||||
|
ret = r
|
||||||
|
elif r is not None:
|
||||||
|
if "end" in r:
|
||||||
|
ret["end"] = r["end"]
|
||||||
|
if "start" in r and "end" in ret:
|
||||||
|
del ret["end"]
|
||||||
|
return ret if ret != {} else None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model = load_silero_vad(onnx=False)
|
||||||
|
vad = FixedVADIterator(model)
|
||||||
|
|
||||||
|
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||||
|
result = vad(audio_buffer)
|
||||||
|
print(f" 512 samples: {result}")
|
||||||
|
|
||||||
|
# test with 511 samples
|
||||||
|
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||||
|
result = vad(audio_buffer)
|
||||||
6
whisperlivekit/simul_whisper/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SimulStreamingASR",
|
||||||
|
"SimulStreamingOnlineProcessor",
|
||||||
|
]
|
||||||
303
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
import platform
|
||||||
|
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||||
|
from whisperlivekit.warmup import load_file
|
||||||
|
from .whisper import load_model, tokenizer
|
||||||
|
from .whisper.audio import TOKENS_PER_SECOND
|
||||||
|
import os
|
||||||
|
import gc
|
||||||
|
from pathlib import Path
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||||
|
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
|
||||||
|
from whisperlivekit.simul_whisper.whisper import tokenizer
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||||
|
HAS_MLX_WHISPER = True
|
||||||
|
except ImportError:
|
||||||
|
if platform.system() == "Darwin" and platform.machine() == "arm64":
|
||||||
|
print(f"""{"="*50}\nMLX Whisper not found but you are on Apple Silicon. Consider installing mlx-whisper for better performance: pip install mlx-whisper\n{"="*50}""")
|
||||||
|
HAS_MLX_WHISPER = False
|
||||||
|
if HAS_MLX_WHISPER:
|
||||||
|
HAS_FASTER_WHISPER = False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
HAS_FASTER_WHISPER = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_FASTER_WHISPER = False
|
||||||
|
|
||||||
|
def model_path_and_type(model_path):
|
||||||
|
path = Path(model_path)
|
||||||
|
|
||||||
|
compatible_whisper_mlx = False
|
||||||
|
compatible_faster_whisper = False
|
||||||
|
pt_path = path if path.is_file() and path.suffix.lower() == '.pt' else None
|
||||||
|
|
||||||
|
if path.is_dir():
|
||||||
|
for file in path.iterdir():
|
||||||
|
if file.is_file():
|
||||||
|
if file.name in ['weights.npz', "weights.safetensors"]:
|
||||||
|
compatible_whisper_mlx = True
|
||||||
|
elif file.suffix.lower() == '.bin':
|
||||||
|
compatible_faster_whisper = True
|
||||||
|
elif file.suffix.lower() == '.pt':
|
||||||
|
pt_path = file
|
||||||
|
return pt_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||||
|
|
||||||
|
|
||||||
|
class SimulStreamingOnlineProcessor:
|
||||||
|
SAMPLING_RATE = 16000
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
asr,
|
||||||
|
logfile=sys.stderr,
|
||||||
|
):
|
||||||
|
self.asr = asr
|
||||||
|
self.logfile = logfile
|
||||||
|
self.end = 0.0
|
||||||
|
self.buffer = []
|
||||||
|
self.committed: List[ASRToken] = []
|
||||||
|
self.last_result_tokens: List[ASRToken] = []
|
||||||
|
self.load_new_backend()
|
||||||
|
|
||||||
|
#can be moved
|
||||||
|
if asr.tokenizer:
|
||||||
|
self.model.tokenizer = asr.tokenizer
|
||||||
|
|
||||||
|
def load_new_backend(self):
|
||||||
|
model = self.asr.get_new_model_instance()
|
||||||
|
self.model = PaddedAlignAttWhisper(
|
||||||
|
cfg=self.asr.cfg,
|
||||||
|
loaded_model=model,
|
||||||
|
mlx_encoder=self.asr.mlx_encoder,
|
||||||
|
fw_encoder=self.asr.fw_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert_silence(self, silence_duration, offset):
|
||||||
|
"""
|
||||||
|
If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
|
||||||
|
"""
|
||||||
|
if silence_duration < 5:
|
||||||
|
gap_silence = torch.zeros(int(16000*silence_duration))
|
||||||
|
self.model.insert_audio(gap_silence)
|
||||||
|
# self.global_time_offset += silence_duration
|
||||||
|
else:
|
||||||
|
self.process_iter(is_last=True) # fully process whatever remains in the buffer
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
self.model.global_time_offset = silence_duration + offset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||||
|
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||||
|
|
||||||
|
# Convert numpy array to torch tensor
|
||||||
|
audio_tensor = torch.from_numpy(audio).float()
|
||||||
|
self.end = audio_stream_end_time #Only to be aligned with what happens in whisperstreaming backend.
|
||||||
|
self.model.insert_audio(audio_tensor)
|
||||||
|
|
||||||
|
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||||
|
self.process_iter(is_last=True)
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
self.model.speaker = change_speaker.speaker
|
||||||
|
self.global_time_offset = change_speaker.start
|
||||||
|
|
||||||
|
def get_buffer(self):
|
||||||
|
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||||
|
return concat_buffer
|
||||||
|
|
||||||
|
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||||
|
"""
|
||||||
|
Process accumulated audio chunks using SimulStreaming.
|
||||||
|
|
||||||
|
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
timestamped_words = self.model.infer(is_last=is_last)
|
||||||
|
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
||||||
|
self.buffer.extend(timestamped_words)
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
self.committed.extend(timestamped_words)
|
||||||
|
self.buffer = []
|
||||||
|
return timestamped_words, self.end
|
||||||
|
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming processing error: {e}")
|
||||||
|
return [], self.end
|
||||||
|
|
||||||
|
def warmup(self, audio, init_prompt=""):
|
||||||
|
"""Warmup the SimulStreaming model."""
|
||||||
|
try:
|
||||||
|
self.model.insert_audio(audio)
|
||||||
|
self.model.infer(True)
|
||||||
|
self.model.refresh_segment(complete=True)
|
||||||
|
logger.info("SimulStreaming model warmed up successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# free the model and add a new model to stack.
|
||||||
|
# del self.model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
# self.asr.new_model_to_stack()
|
||||||
|
self.model.remove_hooks()
|
||||||
|
|
||||||
|
class SimulStreamingASR():
|
||||||
|
"""SimulStreaming backend with AlignAtt policy."""
|
||||||
|
sep = ""
|
||||||
|
|
||||||
|
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||||
|
self.logfile = logfile
|
||||||
|
self.transcribe_kargs = {}
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
if self.decoder_type is None:
|
||||||
|
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||||
|
|
||||||
|
self.fast_encoder = False
|
||||||
|
|
||||||
|
pt_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||||
|
if self.model_path:
|
||||||
|
pt_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(self.model_path)
|
||||||
|
|
||||||
|
elif self.model_size is not None:
|
||||||
|
model_mapping = {
|
||||||
|
'tiny': './tiny.pt',
|
||||||
|
'base': './base.pt',
|
||||||
|
'small': './small.pt',
|
||||||
|
'medium': './medium.pt',
|
||||||
|
'medium.en': './medium.en.pt',
|
||||||
|
'large-v1': './large-v1.pt',
|
||||||
|
'base.en': './base.en.pt',
|
||||||
|
'small.en': './small.en.pt',
|
||||||
|
'tiny.en': './tiny.en.pt',
|
||||||
|
'large-v2': './large-v2.pt',
|
||||||
|
'large-v3': './large-v3.pt',
|
||||||
|
'large': './large-v3.pt'
|
||||||
|
}
|
||||||
|
pt_path = Path(model_mapping.get(self.model_size, f'./{self.model_size}.pt'))
|
||||||
|
|
||||||
|
self.model_name = pt_path.name.replace(".pt", "")
|
||||||
|
|
||||||
|
self.cfg = AlignAttConfig(
|
||||||
|
tokenizer_is_multilingual= not self.model_name.endswith(".en"),
|
||||||
|
segment_length=self.min_chunk_size,
|
||||||
|
frame_threshold=self.frame_threshold,
|
||||||
|
language=self.lan,
|
||||||
|
audio_max_len=self.audio_max_len,
|
||||||
|
audio_min_len=self.audio_min_len,
|
||||||
|
cif_ckpt_path=self.cif_ckpt_path,
|
||||||
|
decoder_type="beam",
|
||||||
|
beam_size=self.beams,
|
||||||
|
task=self.task,
|
||||||
|
never_fire=self.never_fire,
|
||||||
|
init_prompt=self.init_prompt,
|
||||||
|
max_context_tokens=self.max_context_tokens,
|
||||||
|
static_init_prompt=self.static_init_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up tokenizer for translation if needed
|
||||||
|
if self.task == "translate":
|
||||||
|
self.tokenizer = self.set_translate_task()
|
||||||
|
else:
|
||||||
|
self.tokenizer = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
self.mlx_encoder, self.fw_encoder = None, None
|
||||||
|
if not self.disable_fast_encoder:
|
||||||
|
if HAS_MLX_WHISPER:
|
||||||
|
print('Simulstreaming will use MLX whisper to increase encoding speed.')
|
||||||
|
if self.model_path and compatible_whisper_mlx:
|
||||||
|
mlx_model = self.model_path
|
||||||
|
else:
|
||||||
|
mlx_model = mlx_model_mapping[self.model_name]
|
||||||
|
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||||
|
self.fast_encoder = True
|
||||||
|
elif HAS_FASTER_WHISPER and compatible_faster_whisper:
|
||||||
|
print('Simulstreaming will use Faster Whisper for the encoder.')
|
||||||
|
if self.model_path and compatible_faster_whisper:
|
||||||
|
fw_model = self.model_path
|
||||||
|
else:
|
||||||
|
fw_model = self.model_name
|
||||||
|
self.fw_encoder = WhisperModel(
|
||||||
|
fw_model,
|
||||||
|
device='auto',
|
||||||
|
compute_type='auto',
|
||||||
|
)
|
||||||
|
self.fast_encoder = True
|
||||||
|
|
||||||
|
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
whisper_model = load_model(
|
||||||
|
name=self.model_path if self.model_path else self.model_name,
|
||||||
|
download_root=self.model_path,
|
||||||
|
decoder_only=self.fast_encoder,
|
||||||
|
custom_alignment_heads=self.custom_alignment_heads
|
||||||
|
)
|
||||||
|
warmup_audio = load_file(self.warmup_file)
|
||||||
|
if warmup_audio is not None:
|
||||||
|
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||||
|
if self.fast_encoder:
|
||||||
|
temp_model = PaddedAlignAttWhisper(
|
||||||
|
cfg=self.cfg,
|
||||||
|
loaded_model=whisper_model,
|
||||||
|
mlx_encoder=self.mlx_encoder,
|
||||||
|
fw_encoder=self.fw_encoder,
|
||||||
|
)
|
||||||
|
temp_model.warmup(warmup_audio)
|
||||||
|
temp_model.remove_hooks()
|
||||||
|
else:
|
||||||
|
# For standard encoder, use the original transcribe warmup
|
||||||
|
warmup_audio = load_file(self.warmup_file)
|
||||||
|
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||||
|
return whisper_model
|
||||||
|
|
||||||
|
def get_new_model_instance(self):
|
||||||
|
"""
|
||||||
|
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
||||||
|
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
||||||
|
"""
|
||||||
|
if len(self.models) == 0:
|
||||||
|
self.models.append(self.load_model())
|
||||||
|
new_model = self.models.pop()
|
||||||
|
return new_model
|
||||||
|
# self.models[0]
|
||||||
|
|
||||||
|
def new_model_to_stack(self):
|
||||||
|
self.models.append(self.load_model())
|
||||||
|
|
||||||
|
|
||||||
|
def set_translate_task(self):
|
||||||
|
"""Set up translation task."""
|
||||||
|
if self.cfg.language == 'auto':
|
||||||
|
raise Exception('Translation cannot be done with language = auto')
|
||||||
|
return tokenizer.get_tokenizer(
|
||||||
|
multilingual=True,
|
||||||
|
language=self.cfg.language,
|
||||||
|
num_languages=99,
|
||||||
|
task="translate"
|
||||||
|
)
|
||||||
|
|
||||||
|
def transcribe(self, audio):
|
||||||
|
"""
|
||||||
|
Warmup is done directly in load_model
|
||||||
|
"""
|
||||||
|
pass
|
||||||
17
whisperlivekit/simul_whisper/beam.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from .whisper.decoding import PyTorchInference
|
||||||
|
|
||||||
|
# extension of PyTorchInference for beam search
|
||||||
|
class BeamPyTorchInference(PyTorchInference):
|
||||||
|
|
||||||
|
def _kv_modules(self):
|
||||||
|
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
|
||||||
|
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
|
||||||
|
return key_modules + value_modules
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices):
|
||||||
|
if source_indices != list(range(len(source_indices))):
|
||||||
|
for module_cache_id in self._kv_modules():
|
||||||
|
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
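# In beam search the decoder's cached keys/values must follow the surviving hypotheses:
# source_indices says, for each beam slot, which previous beam it was expanded from, so
# the cache rows are re-gathered accordingly.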
|
||||||
|
from torch import Tensor
|
||||||
|
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||||
|
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||||
25
whisperlivekit/simul_whisper/config.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AlignAttConfig():
|
||||||
|
eval_data_path: str = "tmp"
|
||||||
|
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||||
|
frame_threshold: int = 4
|
||||||
|
rewind_threshold: int = 200
|
||||||
|
audio_max_len: float = 20.0
|
||||||
|
cif_ckpt_path: str = ""
|
||||||
|
never_fire: bool = False
|
||||||
|
language: str = field(default="zh")
|
||||||
|
nonspeech_prob: float = 0.5
|
||||||
|
audio_min_len: float = 1.0
|
||||||
|
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||||
|
beam_size: int = 5
|
||||||
|
task: Literal["transcribe","translate"] = "transcribe"
|
||||||
|
tokenizer_is_multilingual: bool = False
|
||||||
|
init_prompt: str = field(default=None)
|
||||||
|
static_init_prompt: str = field(default=None)
|
||||||
|
max_context_tokens: int = field(default=None)
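# Construction sketch (illustrative values only, not the defaults used by the server):
# cfg = AlignAttConfig(segment_length=0.5, frame_threshold=25, language="en",
#                      decoder_type="beam", beam_size=5, task="transcribe")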
|
||||||
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
📄 SimulStreaming (https://github.com/ufal/SimulStreaming) Licence
|
|
||||||
|
|
||||||
SimulStreaming is dual-licensed:
|
|
||||||
|
|
||||||
🔹 Non-Commercial Use
|
|
||||||
|
|
||||||
You may use SimulStreaming under the **PolyForm Noncommercial License 1.0.0** if you
|
|
||||||
obtain the code through the GitHub repository. This license is **free of charge**
|
|
||||||
and comes with **no obligations** for non-commercial users.
|
|
||||||
|
|
||||||
🔸 Commercial Use
|
|
||||||
|
|
||||||
Understanding who uses SimulStreaming commercially helps us improve and
|
|
||||||
prioritize development. Therefore, we want to **require registration** of those who acquire a commercial licence.
|
|
||||||
|
|
||||||
We plan to make the commercial licences **affordable** to SMEs and individuals. We
|
|
||||||
are considering to provide commercial licenses either for free or for symbolic
|
|
||||||
one-time fee, and maybe also provide additional support. You can share your preference via the [questionnaire](https://forms.cloud.microsoft/e/7tCxb4gJfB).
|
|
||||||
|
|
||||||
You can also leave your contact [there](https://forms.cloud.microsoft/e/7tCxb4gJfB) to be notified when the commercial licenses become
|
|
||||||
available.
|
|
||||||
|
|
||||||
✉️ Contact
|
|
||||||
|
|
||||||
[Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek/), machacek@ufal.mff.cuni.cz
|
|
||||||
whisperlivekit/simul_whisper/eow_detection.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import torch

# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper

def load_cif(cfg, n_audio_state, device):
    """cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
    cif_linear = torch.nn.Linear(n_audio_state, 1)
    if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
        if cfg.never_fire:
            never_fire = True
            always_fire = False
        else:
            always_fire = True
            never_fire = False
    else:
        always_fire = False
        never_fire = cfg.never_fire
        checkpoint = torch.load(cfg.cif_ckpt_path)
        cif_linear.load_state_dict(checkpoint)
    cif_linear.to(device)
    return cif_linear, always_fire, never_fire


# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
def resize(alphas, target_lengths, threshold=0.999):
    """
    alpha in thresh=1.0 | (0.0, +0.21)
    target_lengths: if None, apply round and resize, else apply scaling
    """
    # sum
    _num = alphas.sum(-1)
    num = target_lengths.float()
    # scaling
    _alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
    # remove attention values that exceed the threshold
    count = 0
    while len(torch.where(_alphas > threshold)[0]):
        count += 1
        if count > 10:
            break
        xs, ys = torch.where(_alphas > threshold)
        for x, y in zip(xs, ys):
            if _alphas[x][y] >= threshold:
                mask = _alphas[x].ne(0).float()
                mean = 0.5 * _alphas[x].sum() / mask.sum()
                _alphas[x] = _alphas[x] * 0.5 + mean * mask

    return _alphas, _num


def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
    content_mel_len = chunked_encoder_feature.shape[1]  # B, T, D
    alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2)  # B, T
    alphas = torch.sigmoid(alphas)
    decode_length = torch.round(alphas.sum(-1)).int()
    alphas, _ = resize(alphas, decode_length)
    alphas = alphas.squeeze(0)  # (T, )
    threshold = 0.999
    integrate = torch.cumsum(alphas[:-1], dim=0)  # ignore the peak value at the end of the content chunk
    exceed_count = integrate[-1] // threshold
    integrate = integrate - exceed_count * 1.0  # subtract 1 every time integrate exceeds the threshold
    important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
    if important_positions.numel() == 0:
        return False
    else:
        return important_positions[0] >= content_mel_len - 2
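A hedged sketch of how the boundary detector above could be wired into a transcriber. Here `cfg`, `model` and `encoder_feature` are assumed to come from AlignAttConfig and the Whisper encoder; only the call pattern is taken from the file above.

# Hypothetical wiring of the CIF end-of-word check.
from whisperlivekit.simul_whisper.eow_detection import load_cif, fire_at_boundary

cif_linear, always_fire, never_fire = load_cif(
    cfg, n_audio_state=model.dims.n_audio_state, device=model.device)

# encoder_feature: (B, T, D) features of the un-padded part of the current chunk
if always_fire or (not never_fire and fire_at_boundary(encoder_feature, cif_linear)):
    pass  # chunk ends on a word boundary, so the full hypothesis can be emitted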
whisperlivekit/simul_whisper/generation_progress.py (new file, 43 lines)
@@ -0,0 +1,43 @@
class Tokens:
    def __init__(self, tokens):
        self.tokens = tokens

    # def clone(self):
    #     return Tokens(self.tokens.clone())

    def __str__(self):
        return str(self.tokens.tolist())

    def __repr__(self):
        return self.__str__()


class BeamTokens(Tokens):
    def __init__(self, tokens, beam_size):
        self.tokens = tokens
        self.beam_size = beam_size

    def clone(self):
        return BeamTokens(self.tokens.clone(), self.beam_size)

    def __str__(self):
        return f"BeamTokens({self.tokens.tolist()}, beam_size={self.beam_size})"

    def __repr__(self):
        return self.__str__()

    def as_text(self, tokenizer):
        return tokenizer.decode(self.tokens)


class Logits(Tokens):
    def __init__(self, logits):
        super().__init__(logits)

    # def clone(self):
    #     return Logits(self.tokens.clone(), self.beam_size)

    def __str__(self):
        return f"Logits({self.tokens.shape})"

    def __repr__(self):
        return self.__str__()
whisperlivekit/simul_whisper/mlx_encoder.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import json
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten

from mlx_whisper import whisper

mlx_model_mapping = {
    "tiny.en": "mlx-community/whisper-tiny.en-mlx",
    "tiny": "mlx-community/whisper-tiny-mlx",
    "base.en": "mlx-community/whisper-base.en-mlx",
    "base": "mlx-community/whisper-base-mlx",
    "small.en": "mlx-community/whisper-small.en-mlx",
    "small": "mlx-community/whisper-small-mlx",
    "medium.en": "mlx-community/whisper-medium.en-mlx",
    "medium": "mlx-community/whisper-medium-mlx",
    "large-v1": "mlx-community/whisper-large-v1-mlx",
    "large-v2": "mlx-community/whisper-large-v2-mlx",
    "large-v3": "mlx-community/whisper-large-v3-mlx",
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
    "large": "mlx-community/whisper-large-mlx",
}

def load_mlx_encoder(
    path_or_hf_repo: str,
    dtype: mx.Dtype = mx.float32,
) -> whisper.Whisper:
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(snapshot_download(repo_id=path_or_hf_repo))

    with open(str(model_path / "config.json"), "r") as f:
        config = json.loads(f.read())
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)

    model_args = whisper.ModelDimensions(**config)

    wf = model_path / "weights.safetensors"
    if not wf.exists():
        wf = model_path / "weights.npz"
    weights = mx.load(str(wf))

    model = whisper.Whisper(model_args, dtype)

    if quantization is not None:
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(model, **quantization, class_predicate=class_predicate)

    weights = tree_unflatten(list(weights.items()))

    # we only want to load the encoder weights here.
    # Size examples: for tiny.en,
    # Decoder weights: 59110771 bytes
    # Encoder weights: 15268874 bytes

    encoder_weights = {}
    encoder_weights['encoder'] = weights['encoder']
    del(weights)

    model.update(encoder_weights)
    mx.eval(model.parameters())
    return model
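A short usage sketch of the loader above. The `mel` array is an assumption (a log-Mel feature matrix prepared elsewhere); only `load_mlx_encoder` and `mlx_model_mapping` come from the file above.

# Sketch: resolve a model alias and load only the Whisper encoder on MLX.
from whisperlivekit.simul_whisper.mlx_encoder import load_mlx_encoder, mlx_model_mapping

repo = mlx_model_mapping["base.en"]       # "mlx-community/whisper-base.en-mlx"
mlx_model = load_mlx_encoder(repo)        # decoder weights are discarded by the loader
features = mlx_model.encoder(mel[None])   # mel: (n_frames, n_mels) MLX array, assumed input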
whisperlivekit/simul_whisper/simul_whisper.py (new file, 635 lines)
@@ -0,0 +1,635 @@
# This code was originally in simul_whisper/transcriber/simul_whisper.py . It is adapted a lot for SimulStreaming.

import os
import logging

import torch
import torch.nn.functional as F

from .whisper import load_model, DecodingOptions, tokenizer
from .config import AlignAttConfig
from whisperlivekit.timed_objects import ASRToken
from .whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
from .whisper.timing import median_filter
from .whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens, detect_language
from .beam import BeamPyTorchInference
from .eow_detection import fire_at_boundary, load_cif
from time import time
from .token_buffer import TokenBuffer

import numpy as np
from ..timed_objects import PUNCTUATION_MARKS
from .generation_progress import *

DEC_PAD = 50257
logger = logging.getLogger(__name__)


try:
    from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
    from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
    HAS_MLX_WHISPER = True
except ImportError:
    HAS_MLX_WHISPER = False
if HAS_MLX_WHISPER:
    HAS_FASTER_WHISPER = False
else:
    try:
        from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
        from faster_whisper.feature_extractor import FeatureExtractor
        HAS_FASTER_WHISPER = True
    except ImportError:
        HAS_FASTER_WHISPER = False

class PaddedAlignAttWhisper:
    def __init__(
        self,
        cfg: AlignAttConfig,
        loaded_model=None,
        mlx_encoder=None,
        fw_encoder=None,
    ) -> None:
        self.log_segments = 0

        self.model = loaded_model
        self.mlx_encoder = mlx_encoder
        self.fw_encoder = fw_encoder
        if fw_encoder:
            self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        logger.info(f"Model dimensions: {self.model.dims}")
        self.speaker = -1
        self.decode_options = DecodingOptions(
            language=cfg.language,
            without_timestamps=True,
            task=cfg.task
        )
        self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
        self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
        # self.create_tokenizer('en')
        self.detected_language = cfg.language if cfg.language != "auto" else None
        self.global_time_offset = 0.0
        self.reset_tokenizer_to_auto_next_call = False

        self.max_text_len = self.model.dims.n_text_ctx
        self.num_decoder_layers = len(self.model.decoder.blocks)
        self.cfg = cfg
        self.l_hooks = []

        # model to detect end-of-word boundary at the end of the segment
        self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
            n_audio_state=self.model.dims.n_audio_state,
            device=self.model.device)

        # install hooks to access encoder-decoder attention
        self.dec_attns = []
        def layer_hook(module, net_input, net_output):
            # net_output[1]: B*num_head*token_len*audio_len
            t = F.softmax(net_output[1], dim=-1)
            self.dec_attns.append(t.squeeze(0))
        for b in self.model.decoder.blocks:
            hook = b.cross_attn.register_forward_hook(layer_hook)
            self.l_hooks.append(hook)

        self.kv_cache = {}
        def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
            if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
                # save as-is, for the first token or cross attention
                self.kv_cache[module.cache_id] = net_output
            else:
                x = self.kv_cache[module.cache_id]
                self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
            return self.kv_cache[module.cache_id]

        for i, b in enumerate(self.model.decoder.blocks):
            hooks = [
                b.attn.key.register_forward_hook(kv_hook),
                b.attn.value.register_forward_hook(kv_hook),
                b.cross_attn.key.register_forward_hook(kv_hook),
                b.cross_attn.value.register_forward_hook(kv_hook),
            ]
            self.l_hooks.extend(hooks)

        self.align_source = {}
        self.num_align_heads = 0
        for layer_rank, head_id in self.model.alignment_heads.indices().T:
            layer_rank = layer_rank.item()
            heads = self.align_source.get(layer_rank, [])
            heads.append((self.num_align_heads, head_id.item()))
            self.align_source[layer_rank] = heads
            self.num_align_heads += 1


        # tokens to be suppressed from decoding, to prevent hallucinations
        suppress_tokens = [
            self.tokenizer.transcribe,
            self.tokenizer.translate,
            self.tokenizer.sot,
            self.tokenizer.sot_prev,
            self.tokenizer.sot_lm,
            # self.tokenizer.eot
            self.tokenizer.no_timestamps,  # added by DM
        ] + list(self.tokenizer.all_language_tokens)  # added by DM
        if self.tokenizer.no_speech is not None:
            suppress_tokens.append(self.tokenizer.no_speech)
        suppress_tokens = tuple(sorted(set(suppress_tokens)))
        logger.debug(f"Suppress tokens: {suppress_tokens}")
        sup_tokens = SuppressTokens(suppress_tokens)
        self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
        # blank tokens are suppressed for new segments near line 334

        # it's going to be regenerated after lang id
        self.segments = []
        self.init_tokens()

        self.last_attend_frame = -self.cfg.rewind_threshold
        self.cumulative_time_offset = 0.0
        self.first_timestamp = None

        if self.cfg.max_context_tokens is None:
            self.max_context_tokens = self.max_text_len
        else:
            self.max_context_tokens = self.cfg.max_context_tokens
        self.init_context()

        # decoder type: greedy or beam
        if cfg.decoder_type == "greedy":
            logger.info("Using greedy decoder")
            self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
            self.decoder_type = "greedy"

        elif cfg.decoder_type == "beam":
            self.decoder_type = "beam"
            self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
            self.inference.kv_cache = self.kv_cache

            self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)

        # Tokens to carry over to next chunk for incomplete UTF-8 characters
        self.pending_incomplete_tokens = []

    def remove_hooks(self):
        for hook in self.l_hooks:
            hook.remove()

    def warmup(self, audio):
        try:
            self.insert_audio(audio)
            self.infer(is_last=True)
            self.refresh_segment(complete=True)
            logger.info("Model warmed up successfully")
        except Exception as e:
            logger.exception(f"Model warmup failed: {e}")

    def create_tokenizer(self, language=None):
        self.tokenizer = tokenizer.get_tokenizer(
            multilingual=self.tokenizer_is_multilingual,
            language=language,
            num_languages=self.model.num_languages,
            task=self.decode_options.task
        )

    def init_context(self):
        kw = {'tokenizer': self.tokenizer,
              'device': self.model.device,
              'prefix_token_ids': [self.tokenizer.sot_prev]}
        self.context = TokenBuffer.empty(**kw)
        if self.cfg.static_init_prompt is not None:
            self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
        if self.cfg.init_prompt is not None:
            self.context.text += self.cfg.init_prompt

    def init_tokens(self):
        logger.debug(f"init tokens, {len(self.segments)}")
        # init tokens (mandatory prompt)
        self.initial_tokens = torch.tensor(
            self.tokenizer.sot_sequence_including_notimestamps,
            dtype=torch.long,
            device=self.model.device).unsqueeze(0)
        self.initial_token_length = self.initial_tokens.shape[1]
        self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
        # self.segments = []
        logger.debug(f"init tokens after, {len(self.segments)}")
        self.tokens = [self.initial_tokens]

    def trim_context(self):
        logger.info("Trimming context")
        c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
        # logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
        logger.info(f"Context text: {self.context.as_text()}")
        # logger.debug(f"Context tensor: {self.context.as_tensor()}")
        l = sum(t.shape[1] for t in self.tokens) + c
        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
        if self.cfg.static_init_prompt is None:
            after = 0
        else:
            after = len(self.cfg.static_init_prompt)
        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
        while c > self.max_context_tokens or l > self.max_text_len - 20:
            t = self.context.trim_words(after=after)
            l -= t
            c -= t
            logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
            if t == 0:
                break
        # logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
        logger.info(f"Context after trim: {self.context.text} (len: {l})")


    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
        if self.cfg.decoder_type == "greedy":
            logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
        else:
            logger.debug(f"Logits shape: {tokens.shape}")
            logit = self.inference.logits(tokens, audio_features)
        return logit


    def refresh_segment(self, complete=False):

        logger.debug("Refreshing segment:")
        self.init_tokens()
        self.last_attend_frame = -self.cfg.rewind_threshold
        self.detected_language = None
        self.cumulative_time_offset = 0.0
        self.init_context()
        logger.debug(f"Context: {self.context}")
        if not complete and len(self.segments) > 2:
            self.segments = self.segments[-2:]
        else:
            logger.debug("removing all segments.")
            self.segments = []
        self.log_segments += 1

        self.pending_incomplete_tokens = []

    def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
        if self.always_fire: return True
        if self.never_fire: return False
        return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)


    def _current_tokens(self):

        toks = self.tokens
        # very first infer: duplicate start of seq to beam_size
        if toks[0].shape[0] == 1:
            toks[0] = toks[0].repeat_interleave(self.cfg.beam_size, dim=0)

        if not self.context.is_empty():
            context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
            toks = [context_toks] + toks

        # make it one tensor
        if len(toks) > 1:
            current_tokens = torch.cat(toks, dim=1)
        else:
            current_tokens = toks[0]
        logger.debug("debug print current_tokens:")
        self.debug_print_tokens(current_tokens)
        return current_tokens


    def debug_print_tokens(self, tokens):
        for i in range(self.cfg.beam_size):
            logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))

    ### audio buffer

    def segments_len(self):
        segments_len = sum(s.shape[0] for s in self.segments) / 16000
        return segments_len

    def _apply_minseglen(self):
        segments_len = self.segments_len()
        # wait for long enough audio to start
        if segments_len < self.cfg.audio_min_len:
            logger.debug("waiting for next segment")
            return False
        return True

    def insert_audio(self, segment=None):
        if segment is not None:
            self.segments.append(segment)

        removed_len = 0
        # len of audio is bigger than buffer_len. Going to remove the first segment
        segments_len = self.segments_len()
        while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
            removed_len = self.segments[0].shape[0] / 16000
            segments_len -= removed_len
            self.last_attend_frame -= int(TOKENS_PER_SECOND * removed_len)
            self.cumulative_time_offset += removed_len  # Track cumulative time removed
            self.segments = self.segments[1:]
            logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
            if len(self.tokens) > 1:
                self.context.append_token_ids(self.tokens[1][0, :].tolist())
                self.tokens = [self.initial_tokens] + self.tokens[2:]
        return removed_len

    def _clean_cache(self):
        '''clean the cache that stores the attention matrices and kv_cache.
        It must be called every time after generation with the model.'''
        # cleaning cache
        self.dec_attns = []
        self.kv_cache = {}
        if self.decoder_type == "beam":
            self.inference.kv_cache = self.kv_cache
            self.token_decoder.reset()

    @torch.no_grad()
    def lang_id(self, encoder_features):
        """Language detection from encoder features.
        This code is trimmed and copy-pasted from whisper.decoding.detect_language .
        """

        # forward pass using a single token, startoftranscript
        n_audio = encoder_features.shape[0]
        x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device)  # [n_audio, 1]
        logits = self.model.logits(x, encoder_features)[:, 0]

        # collect detected languages; suppress all non-language tokens
        mask = torch.ones(logits.shape[-1], dtype=torch.bool)
        mask[list(self.tokenizer.all_language_tokens)] = False
        logits[:, mask] = -np.inf
        language_tokens = logits.argmax(dim=-1)
        language_token_probs = logits.softmax(dim=-1).cpu()
        language_probs = [
            {
                c: language_token_probs[i, j].item()
                for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
            }
            for i in range(n_audio)
        ]

        single = encoder_features.ndim == 2
        if single:
            language_tokens = language_tokens[0]
            language_probs = language_probs[0]

        self._clean_cache()
        return language_tokens, language_probs

    ### transcription / translation

    @torch.no_grad()
    def infer(self, is_last=False):
        new_segment = True
        if len(self.segments) == 0:
            logger.debug("No segments, nothing to do")
            return []
        if not self._apply_minseglen():
            logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
            input_segments = torch.cat(self.segments, dim=0)
            return []

        # input_segments is concatenation of audio, it's one array
        if len(self.segments) > 1:
            input_segments = torch.cat(self.segments, dim=0)
        else:
            input_segments = self.segments[0]

        # if self.cfg.language == "auto" and self.reset_tokenizer_to_auto_next_call:
        #     logger.debug("Resetting tokenizer to auto for new sentence.")
        #     self.create_tokenizer(None)
        #     self.detected_language = None
        #     self.init_tokens()
        #     self.reset_tokenizer_to_auto_next_call = False

        # NEW: we can use a different encoder, before using standard whisper for cross attention with the hooks on the decoder
        beg_encode = time()
        if self.mlx_encoder:
            mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
            mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
            mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
            encoder_feature = torch.as_tensor(mlx_encoder_feature)
            content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0]) / 2)
        elif self.fw_encoder:
            audio_length_seconds = len(input_segments) / 16000
            content_mel_len = int(audio_length_seconds * 100) // 2
            mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
            mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
            encoder_feature_ctranslate = self.fw_encoder.encode(mel)
            if self.device == 'cpu':  # it seems that on gpu, passing StorageView to torch.as_tensor fails and wrapping in the array works
                encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
            try:
                encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
            except TypeError:  # Normally the cpu condition should prevent having exceptions, but just in case:
                encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
        else:
            # mel + padding to 30s
            mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
                                             device=self.device).unsqueeze(0)
            # trim to 3000
            mel = pad_or_trim(mel_padded, N_FRAMES)
            # the len of actual audio
            content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
            encoder_feature = self.model.encoder(mel)
        end_encode = time()
        # print('Encoder duration:', end_encode-beg_encode)

        if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
            seconds_since_start = self.segments_len() - self.first_timestamp
            if seconds_since_start >= 2.0:
                language_tokens, language_probs = self.lang_id(encoder_feature)
                top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
                print(f"Detected language: {top_lan} with p={p:.4f}")
                self.create_tokenizer(top_lan)
                self.last_attend_frame = -self.cfg.rewind_threshold
                self.cumulative_time_offset = 0.0
                self.init_tokens()
                self.init_context()
                self.detected_language = top_lan
                logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")

        self.trim_context()
        current_tokens = self._current_tokens()

        fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])


        sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
        completed = False
        # punctuation_stop = False

        attn_of_alignment_heads = None
        most_attended_frame = None

        token_len_before_decoding = current_tokens.shape[1]

        l_absolute_timestamps = []

        while not completed and current_tokens.shape[1] < self.max_text_len:  # bos is 3 tokens

            if new_segment:
                tokens_for_logits = current_tokens
            else:
                # only need to use the last token except in the first forward pass
                tokens_for_logits = current_tokens[:, -1:]

            logits = self.logits(tokens_for_logits, encoder_feature)  # B, len(tokens), token dict size

            if new_segment and self.tokenizer.no_speech is not None:
                probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
                no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
                if no_speech_probs[0] > self.cfg.nonspeech_prob:
                    logger.info("no speech, stop")
                    break

            logits = logits[:, -1, :]  # logits for the last token

            # suppress blank tokens only at the beginning of the segment
            if new_segment:
                logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
                new_segment = False
            self.suppress_tokens(logits)
            current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)

            logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
            self.debug_print_tokens(current_tokens)

            attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
            for i, attn_mat in enumerate(self.dec_attns):
                layer_rank = int(i % len(self.model.decoder.blocks))
                align_heads_in_layer = self.align_source.get(layer_rank, [])
                if len(align_heads_in_layer) == 0:
                    continue
                for align_head_rank, head_id in align_heads_in_layer:
                    if self.cfg.beam_size == 1:
                        a = attn_mat[head_id, :, :]
                        a = a.unsqueeze(0)
                    else:
                        a = attn_mat[:, head_id, :, :]
                    attn_of_alignment_heads[align_head_rank].append(a)
            tmp = []
            for mat in attn_of_alignment_heads:
                t = torch.cat(mat, dim=1)
                tmp.append(t)
            attn_of_alignment_heads = torch.stack(tmp, dim=1)
            std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
            attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
            attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7)  # from whisper.timing
            attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
            attn_of_alignment_heads = attn_of_alignment_heads[:, :, :content_mel_len]

            # for each beam, the most attended frame is:
            most_attended_frames = torch.argmax(attn_of_alignment_heads[:, -1, :], dim=-1)

            # Calculate absolute timestamps accounting for cumulative offset
            absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]

            logger.debug(str(most_attended_frames.tolist()) + " most att frames")
            logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")

            most_attended_frame = most_attended_frames[0].item()
            l_absolute_timestamps.append(absolute_timestamps[0])

            logger.debug("current tokens" + str(current_tokens.shape))
            if completed:
                # stripping the last token, the eot
                current_tokens = current_tokens[:, :-1]
                break

            # for some rare cases where the attention fails
            if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
                # TODO: check this
                if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
                    logger.debug("omit rewinding from special tokens")
                    self.last_attend_frame = most_attended_frame
                else:
                    logger.debug(
                        f"[rewind detected] current attention pos: {most_attended_frame}, "
                        f"last attention pos: {self.last_attend_frame}; omit this segment")
                    self.last_attend_frame = -self.cfg.rewind_threshold
                    current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
                    break
            else:
                self.last_attend_frame = most_attended_frame

            if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
                logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
                # stripping the last token, the one that is attended too close to the end
                current_tokens = current_tokens[:, :-1]
                break

            # debug print
            for i in range(self.cfg.beam_size):
                logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
                    attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
                    most_attended_frames[i],
                    current_tokens[i, -1].item(),
                    self.tokenizer.decode([current_tokens[i, -1].item()])
                ))

        tokens_to_split = current_tokens[0, token_len_before_decoding:]

        # Prepend pending tokens from previous chunk if any
        if self.pending_incomplete_tokens:
            logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
            pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
            tokens_to_split = torch.cat([pending_tensor, tokens_to_split])

        if fire_detected or is_last:  # or punctuation_stop:
            new_hypothesis = tokens_to_split.flatten().tolist()
            split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
        else:
            # going to truncate the tokens after the last space
            split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
            if len(split_words) > 1:
                new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
            else:
                new_hypothesis = []


        logger.debug(f"new_hypothesis: {new_hypothesis}")
        new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
            device=self.device,
        )
        self.tokens.append(new_tokens)

        logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")

        self._clean_cache()

        if len(l_absolute_timestamps) >= 2 and self.first_timestamp is None:
            self.first_timestamp = l_absolute_timestamps[0]


        timestamped_words = []
        timestamp_idx = 0
        replacement_char = "\ufffd"
        for word, word_tokens in zip(split_words, split_tokens):
            # Skip words containing incomplete UTF-8 from client output
            if replacement_char in word:
                logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
                timestamp_idx += len(word_tokens)
                continue

            try:
                current_timestamp = l_absolute_timestamps[timestamp_idx]
            except:
                pass
            timestamp_idx += len(word_tokens)

            timestamp_entry = ASRToken(
                start=current_timestamp,
                end=current_timestamp + 0.1,
                text=word,
                probability=0.95,
                speaker=self.speaker,
                detected_language=self.detected_language
            ).with_offset(
                self.global_time_offset
            )
            timestamped_words.append(timestamp_entry)

        # Hold incomplete tokens for next chunk
        self.pending_incomplete_tokens = []
        if split_words and replacement_char in split_words[-1]:
            self.pending_incomplete_tokens = split_tokens[-1]
            logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")

        return timestamped_words
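A hypothetical driver loop around the class above, to show how insert_audio and infer are intended to interleave; the chunk source, model name, and chunk size are illustrative and not taken from this diff.

# Illustrative streaming loop (assumptions: audio_chunks yields 16 kHz float32 torch tensors).
from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.whisper import load_model

cfg = AlignAttConfig(language="en", segment_length=1.0)
asr = PaddedAlignAttWhisper(cfg, loaded_model=load_model("base.en"))

for chunk in audio_chunks:
    asr.insert_audio(chunk)                 # may drop old audio beyond cfg.audio_max_len
    for token in asr.infer(is_last=False):  # list of ASRToken with start/end times
        print(token.start, token.text)
asr.infer(is_last=True)                     # flush the remaining hypothesis at the end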
@@ -7,6 +7,7 @@ class TokenBuffer:
         self.prefix_token_ids = prefix_token_ids
         self.tokenizer = tokenizer
         self.device = device
+        self.pending_token_ids = []
 
     def as_token_ids(self, tokenizer=None):
 
@@ -54,8 +55,8 @@
         ids = tokenizer.encode(self.text[after:])
         words, wids = self.tokenizer.split_to_word_tokens(ids)
-        print(words, file=sys.stderr)
-        print(wids, file=sys.stderr)
+        # print(words, file=sys.stderr)
+        # print(wids, file=sys.stderr)
         if not words:
             return 0
         self.text = self.text[:after] + "".join(words[num:])
@@ -64,7 +65,26 @@
     def append_token_ids(self, token_ids):
         tokenizer = self.tokenizer
         assert tokenizer is not None, "Tokenizer is not set."
-        self.text += self.tokenizer.decode(token_ids)
+        all_tokens = self.pending_token_ids + token_ids
+
+        decoded = tokenizer.decode(all_tokens)
+        replacement_char = "\ufffd"
+
+        if replacement_char in decoded:
+            if len(all_tokens) > 1:
+                decoded_partial = tokenizer.decode(all_tokens[:-1])
+
+                if replacement_char not in decoded_partial:
+                    self.text += decoded_partial
+                    self.pending_token_ids = [all_tokens[-1]]
+                else:
+                    self.pending_token_ids = all_tokens
+            else:
+                self.pending_token_ids = all_tokens
+        else:
+            self.text += decoded
+            self.pending_token_ids = []
 
     def as_split_word_tokens(self):
         tokenizer = self.tokenizer
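The hunk above defers token ids whose decoded text still contains the replacement character, so that a multi-byte character split across chunks is only emitted once it is complete. A toy illustration of the same idea, using Python's incremental UTF-8 decoder rather than the TokenBuffer API:

# Toy illustration: hold incomplete UTF-8 bytes until the next chunk completes them.
import codecs

decoder = codecs.getincrementaldecoder("utf-8")()
print(decoder.decode(b"\xe2\x82"))  # ''  (incomplete character is buffered, nothing emitted)
print(decoder.decode(b"\xac"))      # '€' (emitted once the character is complete)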
whisperlivekit/simul_whisper/whisper/__init__.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    decoder_only=False,
    custom_alignment_heads=None
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    alignment_heads = _ALIGNMENT_HEADS.get(name, None)
    if custom_alignment_heads:
        alignment_heads = custom_alignment_heads.encode()

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims, decoder_only=decoder_only)

    if decoder_only:
        checkpoint["model_state_dict"] = {
            k: v for k, v in checkpoint["model_state_dict"].items()
            if 'encoder' not in k
        }

    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
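A brief sketch of the loader above; the model name resolves through _MODELS and is cached under ~/.cache/whisper (or XDG_CACHE_HOME) by default, and decoder_only appears intended for when a separate MLX or faster-whisper encoder is used.

# Sketch of load_model usage; "base.en" is an illustrative choice.
from whisperlivekit.simul_whisper.whisper import load_model

model = load_model("base.en")                            # full encoder + decoder
decoder = load_model("base.en", decoder_only=True)       # skips encoder weights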
whisperlivekit/simul_whisper/whisper/__main__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .transcribe import cli

cli()

whisperlivekit/simul_whisper/whisper/assets/gpt2.tiktoken (new file, 50256 lines)
whisperlivekit/simul_whisper/whisper/assets/mel_filters.npz (new binary file)
whisperlivekit/simul_whisper/whisper/assets/multilingual.tiktoken (new file, 50257 lines)
whisperlivekit/simul_whisper/whisper/audio.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from .utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of the audio

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
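A minimal sketch of the feature pipeline defined above: 16 kHz audio is padded with 30 s of zeros, converted to a log-Mel spectrogram, and then trimmed to the fixed 3000 frames the encoder expects. The file name is hypothetical.

# Sketch: waveform -> padded log-Mel -> fixed-size encoder input.
from whisperlivekit.simul_whisper.whisper.audio import (
    load_audio, log_mel_spectrogram, pad_or_trim, N_SAMPLES, N_FRAMES)

audio = load_audio("speech.wav")                                  # float32, 16 kHz mono (needs ffmpeg)
mel = log_mel_spectrogram(audio, n_mels=80, padding=N_SAMPLES)    # (80, n_frames + 3000)
mel = pad_or_trim(mel, N_FRAMES, axis=-1)                         # (80, 3000)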
826
whisperlivekit/simul_whisper/whisper/decoding.py
Normal file
@@ -0,0 +1,826 @@
|
|||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributions import Categorical
|
||||||
|
|
||||||
|
from .audio import CHUNK_LENGTH
|
||||||
|
from .tokenizer import Tokenizer, get_tokenizer
|
||||||
|
from .utils import compression_ratio
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .model import Whisper
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as a list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs

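Usage sketch for detect_language on a single segment, assuming `model` is a loaded multilingual Whisper instance and `mel` comes from the front-end above; for 2-D input the function returns a single token id and a single probability dict:

language_tokens, language_probs = model.detect_language(mel)
language = max(language_probs, key=language_probs.get)   # e.g. "en"
print(language, round(language_probs[language], 3))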
@dataclass(frozen=True)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation

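Two representative ways to fill in the options above (a sketch, not part of the diff); note that `beam_size` and `best_of` are mutually exclusive, which `_verify_options` below enforces:

beam = DecodingOptions(language="en", beam_size=5, without_timestamps=True)        # deterministic beam search
sampled = DecodingOptions(language="en", temperature=0.7, best_of=5, fp16=False)   # best-of-5 sampling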
@dataclass(frozen=True)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan

class Inference:
|
||||||
|
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||||
|
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices) -> None:
|
||||||
|
"""Update the key-value cache according to the updated beams"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def cleanup_caching(self) -> None:
|
||||||
|
"""Clean up any resources or hooks after decoding is finished"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PyTorchInference(Inference):
|
||||||
|
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||||
|
self.model: "Whisper" = model
|
||||||
|
self.initial_token_length = initial_token_length
|
||||||
|
self.kv_cache = {}
|
||||||
|
self.hooks = []
|
||||||
|
|
||||||
|
key_modules = [block.attn.key for block in self.model.decoder.blocks]
|
||||||
|
value_modules = [block.attn.value for block in self.model.decoder.blocks]
|
||||||
|
self.kv_modules = key_modules + value_modules
|
||||||
|
|
||||||
|
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||||
|
if not self.kv_cache:
|
||||||
|
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||||
|
|
||||||
|
if tokens.shape[-1] > self.initial_token_length:
|
||||||
|
# only need to use the last token except in the first forward pass
|
||||||
|
tokens = tokens[:, -1:]
|
||||||
|
|
||||||
|
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||||
|
|
||||||
|
def cleanup_caching(self):
|
||||||
|
for hook in self.hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
self.kv_cache = {}
|
||||||
|
self.hooks = []
|
||||||
|
|
||||||
|
def rearrange_kv_cache(self, source_indices):
|
||||||
|
if source_indices != list(range(len(source_indices))):
|
||||||
|
for module in self.kv_modules:
|
||||||
|
# update the key/value cache to contain the selected sequences
|
||||||
|
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceRanker:
|
||||||
|
def rank(
|
||||||
|
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Given a list of groups of samples and their cumulative log probabilities,
|
||||||
|
return the indices of the samples in each group to select as the final result
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MaximumLikelihoodRanker(SequenceRanker):
|
||||||
|
"""
|
||||||
|
Select the sample with the highest log probabilities, penalized using either
|
||||||
|
a simple length normalization or Google NMT paper's length penalty
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, length_penalty: Optional[float]):
|
||||||
|
self.length_penalty = length_penalty
|
||||||
|
|
||||||
|
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||||
|
def scores(logprobs, lengths):
|
||||||
|
result = []
|
||||||
|
for logprob, length in zip(logprobs, lengths):
|
||||||
|
if self.length_penalty is None:
|
||||||
|
penalty = length
|
||||||
|
else:
|
||||||
|
# from the Google NMT paper
|
||||||
|
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||||
|
result.append(logprob / penalty)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# get the sequence with the highest score
|
||||||
|
lengths = [[len(t) for t in s] for s in tokens]
|
||||||
|
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenDecoder:
|
||||||
|
def reset(self):
|
||||||
|
"""Initialize any stateful variables for decoding a new sequence"""
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||||
|
) -> Tuple[Tensor, bool]:
|
||||||
|
"""Specify how to select the next token, based on the current trace and logits
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||||
|
|
||||||
|
logits : Tensor, shape = (n_batch, vocab_size)
|
||||||
|
per-token logits of the probability distribution at the current step
|
||||||
|
|
||||||
|
sum_logprobs : Tensor, shape = (n_batch)
|
||||||
|
cumulative log probabilities for each sequence
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||||
|
the tokens, appended with the selected next token
|
||||||
|
|
||||||
|
completed : bool
|
||||||
|
True if all sequences have reached the end of text
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def finalize(
|
||||||
|
self, tokens: Tensor, sum_logprobs: Tensor
|
||||||
|
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||||
|
"""Finalize search and return the final candidate sequences
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence
|
||||||
|
|
||||||
|
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||||
|
cumulative log probabilities for each sequence
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||||
|
sequence of Tensors containing candidate token sequences, for each audio input
|
||||||
|
|
||||||
|
sum_logprobs : List[List[float]], length = n_audio
|
||||||
|
sequence of cumulative log probabilities corresponding to the above
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class GreedyDecoder(TokenDecoder):
|
||||||
|
def __init__(self, temperature: float, eot: int):
|
||||||
|
self.temperature = temperature
|
||||||
|
self.eot = eot
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||||
|
) -> Tuple[Tensor, bool]:
|
||||||
|
if self.temperature == 0:
|
||||||
|
next_tokens = logits.argmax(dim=-1)
|
||||||
|
else:
|
||||||
|
next_tokens = Categorical(logits=logits / self.temperature).sample()
|
||||||
|
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||||
|
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||||
|
|
||||||
|
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||||
|
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||||
|
|
||||||
|
completed = (tokens[:, -1] == self.eot).all()
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||||
|
# make sure each sequence has at least one EOT token at the end
|
||||||
|
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||||
|
return tokens, sum_logprobs.tolist()
|
||||||
|
|
||||||
|
|
||||||
|
class BeamSearchDecoder(TokenDecoder):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
beam_size: int,
|
||||||
|
eot: int,
|
||||||
|
inference: Inference,
|
||||||
|
patience: Optional[float] = None,
|
||||||
|
):
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.eot = eot
|
||||||
|
self.inference = inference
|
||||||
|
self.patience = patience or 1.0
|
||||||
|
self.max_candidates: int = round(beam_size * self.patience)
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.max_candidates > 0
|
||||||
|
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.finished_sequences = None
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||||
|
) -> Tuple[Tensor, bool]:
|
||||||
|
if tokens.shape[0] % self.beam_size != 0:
|
||||||
|
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||||
|
|
||||||
|
n_audio = tokens.shape[0] // self.beam_size
|
||||||
|
if self.finished_sequences is None: # for the first update
|
||||||
|
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||||
|
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
next_tokens, source_indices, finished_sequences = [], [], []
|
||||||
|
for i in range(n_audio):
|
||||||
|
scores, sources, finished = {}, {}, {}
|
||||||
|
|
||||||
|
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||||
|
for j in range(self.beam_size):
|
||||||
|
idx = i * self.beam_size + j
|
||||||
|
prefix = tokens[idx].tolist()
|
||||||
|
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||||
|
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||||
|
sequence = tuple(prefix + [token.item()])
|
||||||
|
scores[sequence] = new_logprob
|
||||||
|
sources[sequence] = idx
|
||||||
|
|
||||||
|
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||||
|
saved = 0
|
||||||
|
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||||
|
if sequence[-1] == self.eot:
|
||||||
|
finished[sequence] = scores[sequence]
|
||||||
|
else:
|
||||||
|
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||||
|
next_tokens.append(sequence)
|
||||||
|
source_indices.append(sources[sequence])
|
||||||
|
|
||||||
|
saved += 1
|
||||||
|
if saved == self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
finished_sequences.append(finished)
|
||||||
|
|
||||||
|
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||||
|
self.inference.rearrange_kv_cache(source_indices)
|
||||||
|
|
||||||
|
# add newly finished sequences to self.finished_sequences
|
||||||
|
assert len(self.finished_sequences) == len(finished_sequences)
|
||||||
|
for previously_finished, newly_finished in zip(
|
||||||
|
self.finished_sequences, finished_sequences
|
||||||
|
):
|
||||||
|
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||||
|
if len(previously_finished) >= self.max_candidates:
|
||||||
|
break # the candidate list is full
|
||||||
|
previously_finished[seq] = newly_finished[seq]
|
||||||
|
|
||||||
|
# mark as completed if all audio inputs have collected enough finished samples
|
||||||
|
completed = all(
|
||||||
|
len(sequences) >= self.max_candidates
|
||||||
|
for sequences in self.finished_sequences
|
||||||
|
)
|
||||||
|
return tokens, completed
|
||||||
|
|
||||||
|
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||||
|
# collect all finished sequences, up to the patience limit, and add unfinished ones if there are not enough
|
||||||
|
sum_logprobs = sum_logprobs.cpu()
|
||||||
|
for i, sequences in enumerate(self.finished_sequences):
|
||||||
|
if (
|
||||||
|
len(sequences) < self.beam_size
|
||||||
|
): # when not enough sequences are finished
|
||||||
|
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||||
|
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||||
|
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||||
|
if len(sequences) >= self.beam_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
tokens: List[List[Tensor]] = [
|
||||||
|
[torch.tensor(seq) for seq in sequences.keys()]
|
||||||
|
for sequences in self.finished_sequences
|
||||||
|
]
|
||||||
|
sum_logprobs: List[List[float]] = [
|
||||||
|
list(sequences.values()) for sequences in self.finished_sequences
|
||||||
|
]
|
||||||
|
return tokens, sum_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
class LogitFilter:
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||||
|
"""Apply any filtering or masking to logits in-place
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
logits : Tensor, shape = (n_batch, vocab_size)
|
||||||
|
per-token logits of the probability distribution at the current step
|
||||||
|
|
||||||
|
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||||
|
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressBlank(LogitFilter):
|
||||||
|
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.sample_begin = sample_begin
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
if tokens.shape[1] == self.sample_begin:
|
||||||
|
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressTokens(LogitFilter):
|
||||||
|
def __init__(self, suppress_tokens: Sequence[int]):
|
||||||
|
self.suppress_tokens = list(suppress_tokens)
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
logits[:, self.suppress_tokens] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class ApplyTimestampRules(LogitFilter):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
sample_begin: int,
|
||||||
|
max_initial_timestamp_index: Optional[int],
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.sample_begin = sample_begin
|
||||||
|
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||||
|
|
||||||
|
def apply(self, logits: Tensor, tokens: Tensor):
|
||||||
|
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||||
|
if self.tokenizer.no_timestamps is not None:
|
||||||
|
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||||
|
|
||||||
|
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||||
|
for k in range(tokens.shape[0]):
|
||||||
|
sampled_tokens = tokens[k, self.sample_begin :]
|
||||||
|
seq = sampled_tokens.tolist()
|
||||||
|
last_was_timestamp = (
|
||||||
|
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
penultimate_was_timestamp = (
|
||||||
|
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
|
||||||
|
if last_was_timestamp:
|
||||||
|
if penultimate_was_timestamp: # has to be non-timestamp
|
||||||
|
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||||
|
else: # cannot be normal text tokens
|
||||||
|
logits[k, : self.tokenizer.eot] = -np.inf
|
||||||
|
|
||||||
|
timestamps = sampled_tokens[
|
||||||
|
sampled_tokens.ge(self.tokenizer.timestamp_begin)
|
||||||
|
]
|
||||||
|
if timestamps.numel() > 0:
|
||||||
|
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||||
|
# also force each segment to have a nonzero length, to prevent infinite looping
|
||||||
|
if last_was_timestamp and not penultimate_was_timestamp:
|
||||||
|
timestamp_last = timestamps[-1]
|
||||||
|
else:
|
||||||
|
timestamp_last = timestamps[-1] + 1
|
||||||
|
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
|
||||||
|
|
||||||
|
if tokens.shape[1] == self.sample_begin:
|
||||||
|
# suppress generating non-timestamp tokens at the beginning
|
||||||
|
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
|
|
||||||
|
# apply the `max_initial_timestamp` option
|
||||||
|
if self.max_initial_timestamp_index is not None:
|
||||||
|
last_allowed = (
|
||||||
|
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||||
|
)
|
||||||
|
logits[:, last_allowed + 1 :] = -np.inf
|
||||||
|
|
||||||
|
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
|
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||||
|
for k in range(tokens.shape[0]):
|
||||||
|
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||||
|
if timestamp_logprob > max_text_token_logprob:
|
||||||
|
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||||
|
|
||||||
|
|
||||||
|
class DecodingTask:
|
||||||
|
inference: Inference
|
||||||
|
sequence_ranker: SequenceRanker
|
||||||
|
decoder: TokenDecoder
|
||||||
|
logit_filters: List[LogitFilter]
|
||||||
|
|
||||||
|
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
language = options.language or "en"
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=options.task,
|
||||||
|
)
|
||||||
|
self.tokenizer: Tokenizer = tokenizer
|
||||||
|
self.options: DecodingOptions = self._verify_options(options)
|
||||||
|
|
||||||
|
self.n_group: int = options.beam_size or options.best_of or 1
|
||||||
|
self.n_ctx: int = model.dims.n_text_ctx
|
||||||
|
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||||
|
|
||||||
|
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||||
|
if self.options.without_timestamps:
|
||||||
|
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||||
|
|
||||||
|
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||||
|
self.sample_begin: int = len(self.initial_tokens)
|
||||||
|
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||||
|
|
||||||
|
# inference: implements the forward pass through the decoder, including kv caching
|
||||||
|
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||||
|
|
||||||
|
# sequence ranker: implements how to rank a group of sampled sequences
|
||||||
|
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||||
|
|
||||||
|
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||||
|
if options.beam_size is not None:
|
||||||
|
self.decoder = BeamSearchDecoder(
|
||||||
|
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||||
|
|
||||||
|
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||||
|
self.logit_filters = []
|
||||||
|
if self.options.suppress_blank:
|
||||||
|
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||||
|
if self.options.suppress_tokens:
|
||||||
|
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||||
|
if not options.without_timestamps:
|
||||||
|
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||||
|
max_initial_timestamp_index = None
|
||||||
|
if options.max_initial_timestamp:
|
||||||
|
max_initial_timestamp_index = round(
|
||||||
|
self.options.max_initial_timestamp / precision
|
||||||
|
)
|
||||||
|
self.logit_filters.append(
|
||||||
|
ApplyTimestampRules(
|
||||||
|
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||||
|
)
|
||||||
|
)
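A worked example of the timestamp quantization set up above; the 1500-frame encoder context is the standard Whisper value and is an assumption here:

CHUNK_LENGTH = 30                                     # seconds per segment (from .audio)
n_audio_ctx = 1500                                    # encoder frames for the standard dims
precision = CHUNK_LENGTH / n_audio_ctx                # 0.02 s per timestamp token
max_initial_timestamp_index = round(1.0 / precision)  # default max_initial_timestamp=1.0 -> index 50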
|
||||||
|
|
||||||
|
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||||
|
if options.beam_size is not None and options.best_of is not None:
|
||||||
|
raise ValueError("beam_size and best_of can't be given together")
|
||||||
|
if options.temperature == 0:
|
||||||
|
if options.best_of is not None:
|
||||||
|
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||||
|
if options.patience is not None and options.beam_size is None:
|
||||||
|
raise ValueError("patience requires beam_size to be given")
|
||||||
|
if options.length_penalty is not None and not (
|
||||||
|
0 <= options.length_penalty <= 1
|
||||||
|
):
|
||||||
|
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
def _get_initial_tokens(self) -> Tuple[int]:
|
||||||
|
tokens = list(self.sot_sequence)
|
||||||
|
|
||||||
|
if prefix := self.options.prefix:
|
||||||
|
prefix_tokens = (
|
||||||
|
self.tokenizer.encode(" " + prefix.strip())
|
||||||
|
if isinstance(prefix, str)
|
||||||
|
else prefix
|
||||||
|
)
|
||||||
|
if self.sample_len is not None:
|
||||||
|
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||||
|
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||||
|
tokens = tokens + prefix_tokens
|
||||||
|
|
||||||
|
if prompt := self.options.prompt:
|
||||||
|
prompt_tokens = (
|
||||||
|
self.tokenizer.encode(" " + prompt.strip())
|
||||||
|
if isinstance(prompt, str)
|
||||||
|
else prompt
|
||||||
|
)
|
||||||
|
tokens = (
|
||||||
|
[self.tokenizer.sot_prev]
|
||||||
|
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||||
|
+ tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
return tuple(tokens)
|
||||||
|
|
||||||
|
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||||
|
suppress_tokens = self.options.suppress_tokens
|
||||||
|
|
||||||
|
if isinstance(suppress_tokens, str):
|
||||||
|
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||||
|
|
||||||
|
if -1 in suppress_tokens:
|
||||||
|
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||||
|
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||||
|
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||||
|
suppress_tokens = [] # interpret empty string as an empty list
|
||||||
|
else:
|
||||||
|
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||||
|
|
||||||
|
suppress_tokens.extend(
|
||||||
|
[
|
||||||
|
self.tokenizer.transcribe,
|
||||||
|
self.tokenizer.translate,
|
||||||
|
self.tokenizer.sot,
|
||||||
|
self.tokenizer.sot_prev,
|
||||||
|
self.tokenizer.sot_lm,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.tokenizer.no_speech is not None:
|
||||||
|
# no-speech probability is collected separately
|
||||||
|
suppress_tokens.append(self.tokenizer.no_speech)
|
||||||
|
|
||||||
|
return tuple(sorted(set(suppress_tokens)))
|
||||||
|
|
||||||
|
def _get_audio_features(self, mel: Tensor):
|
||||||
|
if self.options.fp16:
|
||||||
|
mel = mel.half()
|
||||||
|
|
||||||
|
if mel.shape[-2:] == (
|
||||||
|
self.model.dims.n_audio_ctx,
|
||||||
|
self.model.dims.n_audio_state,
|
||||||
|
):
|
||||||
|
# encoded audio features are given; skip audio encoding
|
||||||
|
audio_features = mel
|
||||||
|
else:
|
||||||
|
audio_features = self.model.encoder(mel)
|
||||||
|
|
||||||
|
if audio_features.dtype != (
|
||||||
|
torch.float16 if self.options.fp16 else torch.float32
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_features
|
||||||
|
|
||||||
|
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||||
|
languages = [self.options.language] * audio_features.shape[0]
|
||||||
|
lang_probs = None
|
||||||
|
|
||||||
|
if self.options.language is None or self.options.task == "lang_id":
|
||||||
|
lang_tokens, lang_probs = self.model.detect_language(
|
||||||
|
audio_features, self.tokenizer
|
||||||
|
)
|
||||||
|
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||||
|
if self.options.language is None:
|
||||||
|
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||||
|
|
||||||
|
return languages, lang_probs
|
||||||
|
|
||||||
|
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||||
|
n_batch = tokens.shape[0]
|
||||||
|
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||||
|
no_speech_probs = [np.nan] * n_batch
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in range(self.sample_len):
|
||||||
|
logits = self.inference.logits(tokens, audio_features)
|
||||||
|
|
||||||
|
if (
|
||||||
|
i == 0 and self.tokenizer.no_speech is not None
|
||||||
|
): # save no_speech_probs
|
||||||
|
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||||
|
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||||
|
|
||||||
|
# now we need to consider the logits at the last token only
|
||||||
|
logits = logits[:, -1]
|
||||||
|
|
||||||
|
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||||
|
for logit_filter in self.logit_filters:
|
||||||
|
logit_filter.apply(logits, tokens)
|
||||||
|
|
||||||
|
# expand the tokens tensor with the selected next tokens
|
||||||
|
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||||
|
|
||||||
|
if completed or tokens.shape[-1] > self.n_ctx:
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
self.inference.cleanup_caching()
|
||||||
|
|
||||||
|
return tokens, sum_logprobs, no_speech_probs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||||
|
self.decoder.reset()
|
||||||
|
tokenizer: Tokenizer = self.tokenizer
|
||||||
|
n_audio: int = mel.shape[0]
|
||||||
|
|
||||||
|
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||||
|
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||||
|
|
||||||
|
# detect language if requested, overwriting the language token
|
||||||
|
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||||
|
if self.options.task == "lang_id":
|
||||||
|
return [
|
||||||
|
DecodingResult(
|
||||||
|
audio_features=features, language=language, language_probs=probs
|
||||||
|
)
|
||||||
|
for features, language, probs in zip(
|
||||||
|
audio_features, languages, language_probs
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# repeat text tensors by the group size, for beam search or best-of-n sampling
|
||||||
|
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||||
|
|
||||||
|
# call the main sampling loop
|
||||||
|
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||||
|
|
||||||
|
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||||
|
audio_features = audio_features[:: self.n_group]
|
||||||
|
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||||
|
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||||
|
|
||||||
|
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||||
|
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||||
|
|
||||||
|
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||||
|
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||||
|
tokens: List[List[Tensor]] = [
|
||||||
|
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||||
|
for s in tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
# select the top-ranked sample in each group
|
||||||
|
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||||
|
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||||
|
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||||
|
|
||||||
|
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||||
|
avg_logprobs: List[float] = [
|
||||||
|
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||||
|
]
|
||||||
|
|
||||||
|
fields = (
|
||||||
|
texts,
|
||||||
|
languages,
|
||||||
|
tokens,
|
||||||
|
audio_features,
|
||||||
|
avg_logprobs,
|
||||||
|
no_speech_probs,
|
||||||
|
)
|
||||||
|
if len(set(map(len, fields))) != 1:
|
||||||
|
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||||
|
|
||||||
|
return [
|
||||||
|
DecodingResult(
|
||||||
|
audio_features=features,
|
||||||
|
language=language,
|
||||||
|
tokens=tokens,
|
||||||
|
text=text,
|
||||||
|
avg_logprob=avg_logprob,
|
||||||
|
no_speech_prob=no_speech_prob,
|
||||||
|
temperature=self.options.temperature,
|
||||||
|
compression_ratio=compression_ratio(text),
|
||||||
|
)
|
||||||
|
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||||
|
*fields
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)

    return result[0] if single else result
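End-to-end sketch of calling `decode` as defined above; `model` and the audio helpers are assumptions, only the names defined in this file are taken from the diff:

audio = pad_or_trim(load_audio("sample.wav"))                          # assumed helpers from audio.py
mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
result = decode(model, mel, DecodingOptions(language="en", fp16=False))
print(result.text, result.avg_logprob, result.no_speech_prob)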
350
whisperlivekit/simul_whisper/whisper/model.py
Normal file
@@ -0,0 +1,350 @@
import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

try:
    from torch.nn.functional import scaled_dot_product_attention

    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False


@dataclass
|
||||||
|
class ModelDimensions:
|
||||||
|
n_mels: int
|
||||||
|
n_audio_ctx: int
|
||||||
|
n_audio_state: int
|
||||||
|
n_audio_head: int
|
||||||
|
n_audio_layer: int
|
||||||
|
n_vocab: int
|
||||||
|
n_text_ctx: int
|
||||||
|
n_text_state: int
|
||||||
|
n_text_head: int
|
||||||
|
n_text_layer: int
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.LayerNorm):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.linear(
|
||||||
|
x,
|
||||||
|
self.weight.to(x.dtype),
|
||||||
|
None if self.bias is None else self.bias.to(x.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(nn.Conv1d):
|
||||||
|
def _conv_forward(
|
||||||
|
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||||
|
) -> Tensor:
|
||||||
|
return super()._conv_forward(
|
||||||
|
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sinusoids(length, channels, max_timescale=10000):
|
||||||
|
"""Returns sinusoids for positional embedding"""
|
||||||
|
assert channels % 2 == 0
|
||||||
|
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||||
|
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||||
|
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||||
|
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_sdpa():
|
||||||
|
prev_state = MultiHeadAttention.use_sdpa
|
||||||
|
try:
|
||||||
|
MultiHeadAttention.use_sdpa = False
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
MultiHeadAttention.use_sdpa = prev_state
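Usage sketch for the context manager above. In this copy `use_sdpa` already defaults to False, so the manager mainly acts as a safeguard when the attention weights (`qk`) must be materialized, for example for alignment hooks; `model`, `tokens` and `audio_features` are assumed inputs:

with disable_sdpa():
    logits = model.decoder(tokens, audio_features)   # takes the manual attention path, so qk is computed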
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
use_sdpa = False # Disable SDPA to ensure qk is always computed for hooks
|
||||||
|
|
||||||
|
def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
self.n_head = n_head
|
||||||
|
self.query = Linear(n_state, n_state)
|
||||||
|
self.key = Linear(n_state, n_state, bias=False)
|
||||||
|
self.value = Linear(n_state, n_state)
|
||||||
|
self.out = Linear(n_state, n_state)
|
||||||
|
self.cache_id = cache_id
|
||||||
|
self.key.cache_id = f"{cache_id}_key"
|
||||||
|
self.value.cache_id = f"{cache_id}_value"
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Optional[Tensor] = None,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
q = self.query(x)
|
||||||
|
|
||||||
|
if kv_cache is None or xa is None or self.key not in kv_cache:
|
||||||
|
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||||
|
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||||
|
k = self.key(x if xa is None else xa)
|
||||||
|
v = self.value(x if xa is None else xa)
|
||||||
|
else:
|
||||||
|
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||||
|
k = kv_cache[self.key]
|
||||||
|
v = kv_cache[self.value]
|
||||||
|
|
||||||
|
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||||
|
return self.out(wv), qk
|
||||||
|
|
||||||
|
def qkv_attention(
|
||||||
|
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
n_batch, n_ctx, n_state = q.shape
|
||||||
|
scale = (n_state // self.n_head) ** -0.25
|
||||||
|
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
|
||||||
|
a = scaled_dot_product_attention(
|
||||||
|
q, k, v, is_causal=mask is not None and n_ctx > 1
|
||||||
|
)
|
||||||
|
out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
qk = None
|
||||||
|
else:
|
||||||
|
qk = (q * scale) @ (k * scale).transpose(-1, -2)
|
||||||
|
if mask is not None:
|
||||||
|
qk = qk + mask[:n_ctx, :n_ctx]
|
||||||
|
qk = qk.float()
|
||||||
|
|
||||||
|
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||||
|
out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
qk = qk.detach()
|
||||||
|
|
||||||
|
return out, qk
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualAttentionBlock(nn.Module):
|
||||||
|
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
|
||||||
|
self.attn_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
self.cross_attn = (
|
||||||
|
MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
|
||||||
|
)
|
||||||
|
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||||
|
|
||||||
|
n_mlp = n_state * 4
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||||
|
)
|
||||||
|
self.mlp_ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
xa: Optional[Tensor] = None,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
kv_cache: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
|
||||||
|
if self.cross_attn:
|
||||||
|
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
|
||||||
|
x = x + self.mlp(self.mlp_ln(x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AudioEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||||
|
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||||
|
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||||
|
|
||||||
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
|
[ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
|
||||||
|
)
|
||||||
|
self.ln_post = LayerNorm(n_state)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
"""
|
||||||
|
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||||
|
the mel spectrogram of the audio
|
||||||
|
"""
|
||||||
|
x = F.gelu(self.conv1(x))
|
||||||
|
x = F.gelu(self.conv2(x))
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||||
|
x = (x + self.positional_embedding).to(x.dtype)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.ln_post(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TextDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||||
|
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||||
|
|
||||||
|
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
|
||||||
|
for i in range(n_layer)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.ln = LayerNorm(n_state)
|
||||||
|
|
||||||
|
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||||
|
self.register_buffer("mask", mask, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||||
|
the text tokens
|
||||||
|
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||||
|
the encoded audio features to be attended on
|
||||||
|
"""
|
||||||
|
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||||
|
x = (
|
||||||
|
self.token_embedding(x)
|
||||||
|
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||||
|
)
|
||||||
|
x = x.to(xa.dtype)
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||||
|
|
||||||
|
x = self.ln(x)
|
||||||
|
logits = (
|
||||||
|
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||||
|
).float()
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class Whisper(nn.Module):
|
||||||
|
def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
|
||||||
|
if not decoder_only:
|
||||||
|
self.encoder = AudioEncoder(
|
||||||
|
self.dims.n_mels,
|
||||||
|
self.dims.n_audio_ctx,
|
||||||
|
self.dims.n_audio_state,
|
||||||
|
self.dims.n_audio_head,
|
||||||
|
self.dims.n_audio_layer,
|
||||||
|
)
|
||||||
|
self.decoder = TextDecoder(
|
||||||
|
self.dims.n_vocab,
|
||||||
|
self.dims.n_text_ctx,
|
||||||
|
self.dims.n_text_state,
|
||||||
|
self.dims.n_text_head,
|
||||||
|
self.dims.n_text_layer,
|
||||||
|
)
|
||||||
|
# use the last half among the decoder layers for time alignment by default;
|
||||||
|
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||||
|
all_heads = torch.zeros(
|
||||||
|
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||||
|
)
|
||||||
|
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||||
|
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
|
||||||
|
|
||||||
|
def set_alignment_heads(self, dump: bytes):
|
||||||
|
array = np.frombuffer(
|
||||||
|
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||||
|
).copy()
|
||||||
|
mask = torch.from_numpy(array).reshape(
|
||||||
|
self.dims.n_text_layer, self.dims.n_text_head
|
||||||
|
)
|
||||||
|
self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
|
||||||
|
|
||||||
|
def embed_audio(self, mel: torch.Tensor):
|
||||||
|
return self.encoder(mel)
|
||||||
|
|
||||||
|
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||||
|
return self.decoder(tokens, audio_features)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
return self.decoder(tokens, self.encoder(mel))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_multilingual(self):
|
||||||
|
return self.dims.n_vocab >= 51865
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_languages(self):
|
||||||
|
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
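A worked check of the two properties above; the vocabulary sizes are the standard upstream values and are stated here as assumptions:

# multilingual v1/v2: n_vocab == 51865 -> is_multilingual is True, 51865 - 51765 - 1 == 99 languages
# large-v3:           n_vocab == 51866 -> 51866 - 51765 - 1 == 100 languages
# English-only:       n_vocab == 51864 -> is_multilingual is False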
|
||||||
|
|
||||||
|
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||||
|
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||||
|
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||||
|
intermediate tensors to be reused during later calculations.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
cache : Dict[nn.Module, torch.Tensor]
|
||||||
|
A dictionary object mapping the key/value projection modules to their cache
|
||||||
|
hooks : List[RemovableHandle]
|
||||||
|
List of PyTorch RemovableHandle objects that can be used to remove the hooks
|
||||||
|
"""
|
||||||
|
cache = {**cache} if cache is not None else {}
|
||||||
|
hooks = []
|
||||||
|
|
||||||
|
def save_to_cache(module, _, output):
|
||||||
|
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||||
|
# save as-is, for the first token or cross attention
|
||||||
|
cache[module] = output
|
||||||
|
else:
|
||||||
|
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||||
|
return cache[module]
|
||||||
|
|
||||||
|
def install_hooks(layer: nn.Module):
|
||||||
|
if isinstance(layer, MultiHeadAttention):
|
||||||
|
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||||
|
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||||
|
|
||||||
|
self.decoder.apply(install_hooks)
|
||||||
|
return cache, hooks
|
||||||
|
|
||||||
|
detect_language = detect_language_function
|
||||||
|
transcribe = transcribe_function
|
||||||
|
decode = decode_function
|
||||||
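A sketch of the call pattern install_kv_cache_hooks is designed for, mirroring how PyTorchInference in decoding.py drives it; `tokens`, `next_token` and `audio_features` are assumed inputs:

cache, hooks = model.install_kv_cache_hooks()
logits = model.decoder(tokens, audio_features, kv_cache=cache)       # first pass fills the cache
logits = model.decoder(next_token, audio_features, kv_cache=cache)   # later passes extend the cached keys/values
for hook in hooks:
    hook.remove()                                                    # detach the forward hooks when finished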
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
80
whisperlivekit/simul_whisper/whisper/normalizers/basic.py
Normal file
@@ -0,0 +1,80 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
|
||||||
|
"""
|
||||||
|
Replace any other markers, symbols, and punctuation with a space,
|
||||||
|
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||||
|
"""
|
||||||
|
return "".join(
|
||||||
|
(
|
||||||
|
c
|
||||||
|
if c in keep
|
||||||
|
else (
|
||||||
|
ADDITIONAL_DIACRITICS[c]
|
||||||
|
if c in ADDITIONAL_DIACRITICS
|
||||||
|
else (
|
||||||
|
""
|
||||||
|
if unicodedata.category(c) == "Mn"
|
||||||
|
else " " if unicodedata.category(c)[0] in "MSP" else c
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for c in unicodedata.normalize("NFKD", s)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_symbols(s: str):
|
||||||
|
"""
|
||||||
|
Replace any other markers, symbols, and punctuation with a space, keeping diacritics
|
||||||
|
"""
|
||||||
|
return "".join(
|
||||||
|
" " if unicodedata.category(c)[0] in "MSP" else c
|
||||||
|
for c in unicodedata.normalize("NFKC", s)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
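A quick illustration of the normalizer defined above (output shown approximately; whitespace is collapsed to single spaces):

normalizer = BasicTextNormalizer(remove_diacritics=True)
print(normalizer("Café [noise] (laughs) Über-cool!"))   # -> roughly "cafe uber cool"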
1741
whisperlivekit/simul_whisper/whisper/normalizers/english.json
Normal file
550
whisperlivekit/simul_whisper/whisper/normalizers/english.py
Normal file
@@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
|
||||||
|
"""
|
||||||
|
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||||
|
|
||||||
|
- remove any commas
|
||||||
|
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||||
|
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||||
|
- spell out `one` and `ones`
|
||||||
|
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.zeros = {"o", "oh", "zero"}
|
||||||
|
self.ones = {
|
||||||
|
name: i
|
||||||
|
for i, name in enumerate(
|
||||||
|
[
|
||||||
|
"one",
|
||||||
|
"two",
|
||||||
|
"three",
|
||||||
|
"four",
|
||||||
|
"five",
|
||||||
|
"six",
|
||||||
|
"seven",
|
||||||
|
"eight",
|
||||||
|
"nine",
|
||||||
|
"ten",
|
||||||
|
"eleven",
|
||||||
|
"twelve",
|
||||||
|
"thirteen",
|
||||||
|
"fourteen",
|
||||||
|
"fifteen",
|
||||||
|
"sixteen",
|
||||||
|
"seventeen",
|
||||||
|
"eighteen",
|
||||||
|
"nineteen",
|
||||||
|
],
|
||||||
|
start=1,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.ones_plural = {
|
||||||
|
"sixes" if name == "six" else name + "s": (value, "s")
|
||||||
|
for name, value in self.ones.items()
|
||||||
|
}
|
||||||
|
self.ones_ordinal = {
|
||||||
|
"zeroth": (0, "th"),
|
||||||
|
"first": (1, "st"),
|
||||||
|
"second": (2, "nd"),
|
||||||
|
"third": (3, "rd"),
|
||||||
|
"fifth": (5, "th"),
|
||||||
|
"twelfth": (12, "th"),
|
||||||
|
**{
|
||||||
|
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||||
|
for name, value in self.ones.items()
|
||||||
|
if value > 3 and value != 5 and value != 12
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||||
|
|
||||||
|
self.tens = {
|
||||||
|
"twenty": 20,
|
||||||
|
"thirty": 30,
|
||||||
|
"forty": 40,
|
||||||
|
"fifty": 50,
|
||||||
|
"sixty": 60,
|
||||||
|
"seventy": 70,
|
||||||
|
"eighty": 80,
|
||||||
|
"ninety": 90,
|
||||||
|
}
|
||||||
|
self.tens_plural = {
|
||||||
|
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||||
|
}
|
||||||
|
self.tens_ordinal = {
|
||||||
|
name.replace("y", "ieth"): (value, "th")
|
||||||
|
for name, value in self.tens.items()
|
||||||
|
}
|
||||||
|
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||||
|
|
||||||
|
self.multipliers = {
|
||||||
|
"hundred": 100,
|
||||||
|
"thousand": 1_000,
|
||||||
|
"million": 1_000_000,
|
||||||
|
"billion": 1_000_000_000,
|
||||||
|
"trillion": 1_000_000_000_000,
|
||||||
|
"quadrillion": 1_000_000_000_000_000,
|
||||||
|
"quintillion": 1_000_000_000_000_000_000,
|
||||||
|
"sextillion": 1_000_000_000_000_000_000_000,
|
||||||
|
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||||
|
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||||
|
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||||
|
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||||
|
}
|
||||||
|
self.multipliers_plural = {
|
||||||
|
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||||
|
}
|
||||||
|
self.multipliers_ordinal = {
|
||||||
|
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||||
|
}
|
||||||
|
self.multipliers_suffixed = {
|
||||||
|
**self.multipliers_plural,
|
||||||
|
**self.multipliers_ordinal,
|
||||||
|
}
|
||||||
|
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||||
|
|
||||||
|
self.preceding_prefixers = {
|
||||||
|
"minus": "-",
|
||||||
|
"negative": "-",
|
||||||
|
"plus": "+",
|
||||||
|
"positive": "+",
|
||||||
|
}
|
||||||
|
self.following_prefixers = {
|
||||||
|
"pound": "£",
|
||||||
|
"pounds": "£",
|
||||||
|
"euro": "€",
|
||||||
|
"euros": "€",
|
||||||
|
"dollar": "$",
|
||||||
|
"dollars": "$",
|
||||||
|
"cent": "¢",
|
||||||
|
"cents": "¢",
|
||||||
|
}
|
||||||
|
self.prefixes = set(
|
||||||
|
list(self.preceding_prefixers.values())
|
||||||
|
+ list(self.following_prefixers.values())
|
||||||
|
)
|
||||||
|
self.suffixers = {
|
||||||
|
"per": {"cent": "%"},
|
||||||
|
"percent": "%",
|
||||||
|
}
|
||||||
|
self.specials = {"and", "double", "triple", "point"}
|
||||||
|
|
||||||
|
self.words = set(
|
||||||
|
[
|
||||||
|
key
|
||||||
|
for mapping in [
|
||||||
|
self.zeros,
|
||||||
|
self.ones,
|
||||||
|
self.ones_suffixed,
|
||||||
|
self.tens,
|
||||||
|
self.tens_suffixed,
|
||||||
|
self.multipliers,
|
||||||
|
self.multipliers_suffixed,
|
||||||
|
self.preceding_prefixers,
|
||||||
|
self.following_prefixers,
|
||||||
|
self.suffixers,
|
||||||
|
self.specials,
|
||||||
|
]
|
||||||
|
for key in mapping
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.literal_words = {"one", "ones"}
|
||||||
|
|
||||||
|
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||||
|
prefix: Optional[str] = None
|
||||||
|
value: Optional[Union[str, int]] = None
|
||||||
|
skip = False
|
||||||
|
|
||||||
|
def to_fraction(s: str):
|
||||||
|
try:
|
||||||
|
return Fraction(s)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def output(result: Union[str, int]):
|
||||||
|
nonlocal prefix, value
|
||||||
|
result = str(result)
|
||||||
|
if prefix is not None:
|
||||||
|
result = prefix + result
|
||||||
|
value = None
|
||||||
|
prefix = None
|
||||||
|
return result
|
||||||
|
|
||||||
|
if len(words) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
for prev, current, next in windowed([None] + words + [None], 3):
|
||||||
|
if skip:
|
||||||
|
skip = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||||
|
has_prefix = current[0] in self.prefixes
|
||||||
|
current_without_prefix = current[1:] if has_prefix else current
|
||||||
|
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||||
|
# arabic numbers (potentially with signs and fractions)
|
||||||
|
f = to_fraction(current_without_prefix)
|
||||||
|
assert f is not None
|
||||||
|
if value is not None:
|
||||||
|
if isinstance(value, str) and value.endswith("."):
|
||||||
|
# concatenate decimals / ip address components
|
||||||
|
value = str(value) + str(current)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
prefix = current[0] if has_prefix else prefix
|
||||||
|
if f.denominator == 1:
|
||||||
|
value = f.numerator # store integers as int
|
||||||
|
else:
|
||||||
|
value = current_without_prefix
|
||||||
|
elif current not in self.words:
|
||||||
|
# non-numeric words
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.zeros:
|
||||||
|
value = str(value or "") + "0"
|
||||||
|
elif current in self.ones:
|
||||||
|
ones = self.ones[current]
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
value = ones
|
||||||
|
elif isinstance(value, str) or prev in self.ones:
|
||||||
|
if (
|
||||||
|
prev in self.tens and ones < 10
|
||||||
|
): # replace the last zero with the digit
|
||||||
|
assert value[-1] == "0"
|
||||||
|
value = value[:-1] + str(ones)
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
elif ones < 10:
|
||||||
|
if value % 10 == 0:
|
||||||
|
value += ones
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
else: # eleven to nineteen
|
||||||
|
if value % 100 == 0:
|
||||||
|
value += ones
|
||||||
|
else:
|
||||||
|
value = str(value) + str(ones)
|
||||||
|
elif current in self.ones_suffixed:
|
||||||
|
# ordinal or cardinal; yield the number right away
|
||||||
|
ones, suffix = self.ones_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(ones) + suffix)
|
||||||
|
elif isinstance(value, str) or prev in self.ones:
|
||||||
|
if prev in self.tens and ones < 10:
|
||||||
|
assert value[-1] == "0"
|
||||||
|
yield output(value[:-1] + str(ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
elif ones < 10:
|
||||||
|
if value % 10 == 0:
|
||||||
|
yield output(str(value + ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
else: # eleven to nineteen
|
||||||
|
if value % 100 == 0:
|
||||||
|
yield output(str(value + ones) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(ones) + suffix)
|
||||||
|
value = None
|
||||||
|
elif current in self.tens:
|
||||||
|
tens = self.tens[current]
|
||||||
|
if value is None:
|
||||||
|
value = tens
|
||||||
|
elif isinstance(value, str):
|
||||||
|
value = str(value) + str(tens)
|
||||||
|
else:
|
||||||
|
if value % 100 == 0:
|
||||||
|
value += tens
|
||||||
|
else:
|
||||||
|
value = str(value) + str(tens)
|
||||||
|
elif current in self.tens_suffixed:
|
||||||
|
# ordinal or cardinal; yield the number right away
|
||||||
|
tens, suffix = self.tens_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(tens) + suffix)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
yield output(str(value) + str(tens) + suffix)
|
||||||
|
else:
|
||||||
|
if value % 100 == 0:
|
||||||
|
yield output(str(value + tens) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + str(tens) + suffix)
|
||||||
|
elif current in self.multipliers:
|
||||||
|
multiplier = self.multipliers[current]
|
||||||
|
if value is None:
|
||||||
|
value = multiplier
|
||||||
|
elif isinstance(value, str) or value == 0:
|
||||||
|
f = to_fraction(value)
|
||||||
|
p = f * multiplier if f is not None else None
|
||||||
|
if f is not None and p.denominator == 1:
|
||||||
|
value = p.numerator
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
value = multiplier
|
||||||
|
else:
|
||||||
|
before = value // 1000 * 1000
|
||||||
|
residual = value % 1000
|
||||||
|
value = before + residual * multiplier
|
||||||
|
elif current in self.multipliers_suffixed:
|
||||||
|
multiplier, suffix = self.multipliers_suffixed[current]
|
||||||
|
if value is None:
|
||||||
|
yield output(str(multiplier) + suffix)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
f = to_fraction(value)
|
||||||
|
p = f * multiplier if f is not None else None
|
||||||
|
if f is not None and p.denominator == 1:
|
||||||
|
yield output(str(p.numerator) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
yield output(str(multiplier) + suffix)
|
||||||
|
else: # int
|
||||||
|
before = value // 1000 * 1000
|
||||||
|
residual = value % 1000
|
||||||
|
value = before + residual * multiplier
|
||||||
|
yield output(str(value) + suffix)
|
||||||
|
value = None
|
||||||
|
elif current in self.preceding_prefixers:
|
||||||
|
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
if next in self.words or next_is_numeric:
|
||||||
|
prefix = self.preceding_prefixers[current]
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.following_prefixers:
|
||||||
|
# apply prefix (dollars, cents, etc.) only after a number
|
||||||
|
if value is not None:
|
||||||
|
prefix = self.following_prefixers[current]
|
||||||
|
yield output(value)
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.suffixers:
|
||||||
|
# apply suffix symbols (percent -> '%')
|
||||||
|
if value is not None:
|
||||||
|
suffix = self.suffixers[current]
|
||||||
|
if isinstance(suffix, dict):
|
||||||
|
if next in suffix:
|
||||||
|
yield output(str(value) + suffix[next])
|
||||||
|
skip = True
|
||||||
|
else:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
else:
|
||||||
|
yield output(str(value) + suffix)
|
||||||
|
else:
|
||||||
|
yield output(current)
|
||||||
|
elif current in self.specials:
|
||||||
|
if next not in self.words and not next_is_numeric:
|
||||||
|
# apply special handling only if the next word can be numeric
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "and":
|
||||||
|
# ignore "and" after hundreds, thousands, etc.
|
||||||
|
if prev not in self.multipliers:
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "double" or current == "triple":
|
||||||
|
if next in self.ones or next in self.zeros:
|
||||||
|
repeats = 2 if current == "double" else 3
|
||||||
|
ones = self.ones.get(next, 0)
|
||||||
|
value = str(value or "") + str(ones) * repeats
|
||||||
|
skip = True
|
||||||
|
else:
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
yield output(current)
|
||||||
|
elif current == "point":
|
||||||
|
if next in self.decimals or next_is_numeric:
|
||||||
|
value = str(value or "") + "."
|
||||||
|
else:
|
||||||
|
# should all have been covered at this point
|
||||||
|
raise ValueError(f"Unexpected token: {current}")
|
||||||
|
else:
|
||||||
|
# all should have been covered at this point
|
||||||
|
raise ValueError(f"Unexpected token: {current}")
|
||||||
|
|
||||||
|
if value is not None:
|
||||||
|
yield output(value)
|
||||||
|
|
||||||
|
    def preprocess(self, s: str):
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s

    def postprocess(self, s: str):
        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                return m.string

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.string

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)

        # write "one(s)" instead of "1(s)", just for the readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s

    def __call__(self, s: str):
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s

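Below is a minimal usage sketch of the number normalizer defined above. The import path is an assumption based on this diff's package layout, and the expected outputs follow from the mappings set up in the constructor.

from whisperlivekit.simul_whisper.whisper.normalizers.english import EnglishNumberNormalizer  # path is an assumption

normalize_numbers = EnglishNumberNormalizer()
# spelled-out quantities should be rewritten with digits and symbols,
# e.g. "21%" and "$10.50" for the inputs below
print(normalize_numbers("twenty one percent"))
print(normalize_numbers("ten dollars and fifty cents"))
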
class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        self.mapping = json.load(open(mapping_path))

    def __call__(self, s: str):
        return " ".join(self.mapping.get(word, word) for word in s.split())

class EnglishTextNormalizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||||
|
self.replacers = {
|
||||||
|
# common contractions
|
||||||
|
r"\bwon't\b": "will not",
|
||||||
|
r"\bcan't\b": "can not",
|
||||||
|
r"\blet's\b": "let us",
|
||||||
|
r"\bain't\b": "aint",
|
||||||
|
r"\by'all\b": "you all",
|
||||||
|
r"\bwanna\b": "want to",
|
||||||
|
r"\bgotta\b": "got to",
|
||||||
|
r"\bgonna\b": "going to",
|
||||||
|
r"\bi'ma\b": "i am going to",
|
||||||
|
r"\bimma\b": "i am going to",
|
||||||
|
r"\bwoulda\b": "would have",
|
||||||
|
r"\bcoulda\b": "could have",
|
||||||
|
r"\bshoulda\b": "should have",
|
||||||
|
r"\bma'am\b": "madam",
|
||||||
|
# contractions in titles/prefixes
|
||||||
|
r"\bmr\b": "mister ",
|
||||||
|
r"\bmrs\b": "missus ",
|
||||||
|
r"\bst\b": "saint ",
|
||||||
|
r"\bdr\b": "doctor ",
|
||||||
|
r"\bprof\b": "professor ",
|
||||||
|
r"\bcapt\b": "captain ",
|
||||||
|
r"\bgov\b": "governor ",
|
||||||
|
r"\bald\b": "alderman ",
|
||||||
|
r"\bgen\b": "general ",
|
||||||
|
r"\bsen\b": "senator ",
|
||||||
|
r"\brep\b": "representative ",
|
||||||
|
r"\bpres\b": "president ",
|
||||||
|
r"\brev\b": "reverend ",
|
||||||
|
r"\bhon\b": "honorable ",
|
||||||
|
r"\basst\b": "assistant ",
|
||||||
|
r"\bassoc\b": "associate ",
|
||||||
|
r"\blt\b": "lieutenant ",
|
||||||
|
r"\bcol\b": "colonel ",
|
||||||
|
r"\bjr\b": "junior ",
|
||||||
|
r"\bsr\b": "senior ",
|
||||||
|
r"\besq\b": "esquire ",
|
||||||
|
# perfect tenses, ideally it should be any past participles, but it's harder..
|
||||||
|
r"'d been\b": " had been",
|
||||||
|
r"'s been\b": " has been",
|
||||||
|
r"'d gone\b": " had gone",
|
||||||
|
r"'s gone\b": " has gone",
|
||||||
|
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||||
|
r"'s got\b": " has got",
|
||||||
|
# general contractions
|
||||||
|
r"n't\b": " not",
|
||||||
|
r"'re\b": " are",
|
||||||
|
r"'s\b": " is",
|
||||||
|
r"'d\b": " would",
|
||||||
|
r"'ll\b": " will",
|
||||||
|
r"'t\b": " not",
|
||||||
|
r"'ve\b": " have",
|
||||||
|
r"'m\b": " am",
|
||||||
|
}
|
||||||
|
self.standardize_numbers = EnglishNumberNormalizer()
|
||||||
|
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||||
|
|
||||||
|
def __call__(self, s: str):
|
||||||
|
s = s.lower()
|
||||||
|
|
||||||
|
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||||
|
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||||
|
s = re.sub(self.ignore_patterns, "", s)
|
||||||
|
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||||
|
|
||||||
|
for pattern, replacement in self.replacers.items():
|
||||||
|
s = re.sub(pattern, replacement, s)
|
||||||
|
|
||||||
|
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||||
|
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||||
|
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||||
|
|
||||||
|
s = self.standardize_numbers(s)
|
||||||
|
s = self.standardize_spellings(s)
|
||||||
|
|
||||||
|
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||||
|
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||||
|
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||||
|
|
||||||
|
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||||
|
|
||||||
|
return s
|
||||||
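For reference, a short sketch of the full normalization pipeline above, which chains the contraction replacements, the number normalizer, and the spelling normalizer. The import path is an assumption, and `english.json` must sit next to the module for the spelling mappings to load.

from whisperlivekit.simul_whisper.whisper.normalizers.english import EnglishTextNormalizer  # path is an assumption

normalizer = EnglishTextNormalizer()
# lowercases, expands contractions, converts spelled-out numbers, and applies
# British-to-American spellings; the input below is expected to come out as
# "he is going to pay 21%"
print(normalizer("He's gonna pay twenty one percent"))
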
whisperlivekit/simul_whisper/whisper/timing.py (new file, 388 lines)
@@ -0,0 +1,388 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper


def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result

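A quick sanity check of `median_filter` on a 1D signal (a sketch; the import path is an assumption based on this diff's layout):

import torch
from whisperlivekit.simul_whisper.whisper.timing import median_filter  # path is an assumption

signal = torch.tensor([0.0, 0.0, 10.0, 0.0, 0.0, 0.0])
# with a width-3 window the isolated spike is replaced by the local median,
# so the filtered signal should be all zeros
print(median_filter(signal, filter_width=3))
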
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T


@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)

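A toy run of the CPU dynamic-time-warping routine above (a sketch; the import path is an assumption, and the first call is slow while numba compiles):

import numpy as np
from whisperlivekit.simul_whisper.whisper.timing import dtw_cpu  # path is an assumption

# cost matrix with an obvious diagonal alignment; dtw_cpu returns two index
# arrays (row indices, column indices) tracing the minimum-cost monotonic path
cost = np.array(
    [
        [0.0, 1.0, 2.0, 3.0],
        [1.0, 0.0, 1.0, 2.0],
        [2.0, 1.0, 0.0, 0.0],
    ]
)
rows, cols = dtw_cpu(cost)
print(list(zip(rows, cols)))  # expected to pass through (0, 0), (1, 1), (2, 2), (2, 3)
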
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||||
|
from .triton_ops import dtw_kernel
|
||||||
|
|
||||||
|
M, N = x.shape
|
||||||
|
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||||
|
|
||||||
|
x_skew = (
|
||||||
|
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||||
|
)
|
||||||
|
x_skew = x_skew.T.contiguous()
|
||||||
|
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||||
|
cost[0, 0] = 0
|
||||||
|
cost = cost.to(x.device)
|
||||||
|
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||||
|
|
||||||
|
dtw_kernel[(1,)](
|
||||||
|
cost,
|
||||||
|
trace,
|
||||||
|
x_skew,
|
||||||
|
x_skew.stride(0),
|
||||||
|
cost.stride(0),
|
||||||
|
trace.stride(0),
|
||||||
|
N,
|
||||||
|
M,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||||
|
:, : N + 1
|
||||||
|
]
|
||||||
|
return backtrace(trace.cpu().numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||||
|
if x.is_cuda:
|
||||||
|
try:
|
||||||
|
return dtw_cuda(x)
|
||||||
|
except (RuntimeError, subprocess.CalledProcessError):
|
||||||
|
warnings.warn(
|
||||||
|
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||||
|
"falling back to a slower DTW implementation..."
|
||||||
|
)
|
||||||
|
|
||||||
|
return dtw_cpu(x.double().cpu().numpy())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordTiming:
|
||||||
|
word: str
|
||||||
|
tokens: List[int]
|
||||||
|
start: float
|
||||||
|
end: float
|
||||||
|
probability: float
|
||||||
|
|
||||||
|
|
||||||
|
def find_alignment(
|
||||||
|
model: "Whisper",
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
text_tokens: List[int],
|
||||||
|
mel: torch.Tensor,
|
||||||
|
num_frames: int,
|
||||||
|
*,
|
||||||
|
medfilt_width: int = 7,
|
||||||
|
qk_scale: float = 1.0,
|
||||||
|
) -> List[WordTiming]:
|
||||||
|
if len(text_tokens) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tokens = torch.tensor(
|
||||||
|
[
|
||||||
|
*tokenizer.sot_sequence,
|
||||||
|
tokenizer.no_timestamps,
|
||||||
|
*text_tokens,
|
||||||
|
tokenizer.eot,
|
||||||
|
]
|
||||||
|
).to(model.device)
|
||||||
|
|
||||||
|
# install hooks on the cross attention layers to retrieve the attention weights
|
||||||
|
QKs = [None] * model.dims.n_text_layer
|
||||||
|
hooks = [
|
||||||
|
block.cross_attn.register_forward_hook(
|
||||||
|
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||||
|
)
|
||||||
|
for i, block in enumerate(model.decoder.blocks)
|
||||||
|
]
|
||||||
|
|
||||||
|
from .model import disable_sdpa
|
||||||
|
|
||||||
|
with torch.no_grad(), disable_sdpa():
|
||||||
|
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||||
|
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||||
|
token_probs = sampled_logits.softmax(dim=-1)
|
||||||
|
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||||
|
text_token_probs = text_token_probs.tolist()
|
||||||
|
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
# heads * tokens * frames
|
||||||
|
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||||
|
weights = weights[:, :, : num_frames // 2]
|
||||||
|
weights = (weights * qk_scale).softmax(dim=-1)
|
||||||
|
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
weights = (weights - mean) / std
|
||||||
|
weights = median_filter(weights, medfilt_width)
|
||||||
|
|
||||||
|
matrix = weights.mean(axis=0)
|
||||||
|
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||||
|
text_indices, time_indices = dtw(-matrix)
|
||||||
|
|
||||||
|
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||||
|
if len(word_tokens) <= 1:
|
||||||
|
# return on eot only
|
||||||
|
# >>> np.pad([], (1, 0))
|
||||||
|
# array([0.])
|
||||||
|
# This results in crashes when we lookup jump_times with float, like
|
||||||
|
# IndexError: arrays used as indices must be of integer (or boolean) type
|
||||||
|
return []
|
||||||
|
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||||
|
|
||||||
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
|
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||||
|
start_times = jump_times[word_boundaries[:-1]]
|
||||||
|
end_times = jump_times[word_boundaries[1:]]
|
||||||
|
word_probabilities = [
|
||||||
|
np.mean(text_token_probs[i:j])
|
||||||
|
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
WordTiming(word, tokens, start, end, probability)
|
||||||
|
for word, tokens, start, end, probability in zip(
|
||||||
|
words, word_tokens, start_times, end_times, word_probabilities
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||||
|
# merge prepended punctuations
|
||||||
|
i = len(alignment) - 2
|
||||||
|
j = len(alignment) - 1
|
||||||
|
while i >= 0:
|
||||||
|
previous = alignment[i]
|
||||||
|
following = alignment[j]
|
||||||
|
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||||
|
# prepend it to the following word
|
||||||
|
following.word = previous.word + following.word
|
||||||
|
following.tokens = previous.tokens + following.tokens
|
||||||
|
previous.word = ""
|
||||||
|
previous.tokens = []
|
||||||
|
else:
|
||||||
|
j = i
|
||||||
|
i -= 1
|
||||||
|
|
||||||
|
# merge appended punctuations
|
||||||
|
i = 0
|
||||||
|
j = 1
|
||||||
|
while j < len(alignment):
|
||||||
|
previous = alignment[i]
|
||||||
|
following = alignment[j]
|
||||||
|
if not previous.word.endswith(" ") and following.word in appended:
|
||||||
|
# append it to the previous word
|
||||||
|
previous.word = previous.word + following.word
|
||||||
|
previous.tokens = previous.tokens + following.tokens
|
||||||
|
following.word = ""
|
||||||
|
following.tokens = []
|
||||||
|
else:
|
||||||
|
i = j
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
|
||||||
|
def add_word_timestamps(
|
||||||
|
*,
|
||||||
|
segments: List[dict],
|
||||||
|
model: "Whisper",
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
mel: torch.Tensor,
|
||||||
|
num_frames: int,
|
||||||
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
last_speech_timestamp: float,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if len(segments) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
text_tokens_per_segment = [
|
||||||
|
[token for token in segment["tokens"] if token < tokenizer.eot]
|
||||||
|
for segment in segments
|
||||||
|
]
|
||||||
|
|
||||||
|
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
||||||
|
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||||
|
word_durations = np.array([t.end - t.start for t in alignment])
|
||||||
|
word_durations = word_durations[word_durations.nonzero()]
|
||||||
|
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||||
|
median_duration = min(0.7, float(median_duration))
|
||||||
|
max_duration = median_duration * 2
|
||||||
|
|
||||||
|
# hack: truncate long words at sentence boundaries.
|
||||||
|
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||||
|
if len(word_durations) > 0:
|
||||||
|
sentence_end_marks = ".。!!??"
|
||||||
|
# ensure words at sentence boundaries are not longer than twice the median word duration.
|
||||||
|
for i in range(1, len(alignment)):
|
||||||
|
if alignment[i].end - alignment[i].start > max_duration:
|
||||||
|
if alignment[i].word in sentence_end_marks:
|
||||||
|
alignment[i].end = alignment[i].start + max_duration
|
||||||
|
elif alignment[i - 1].word in sentence_end_marks:
|
||||||
|
alignment[i].start = alignment[i].end - max_duration
|
||||||
|
|
||||||
|
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||||
|
|
||||||
|
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||||
|
word_index = 0
|
||||||
|
|
||||||
|
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
||||||
|
saved_tokens = 0
|
||||||
|
words = []
|
||||||
|
|
||||||
|
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
||||||
|
timing = alignment[word_index]
|
||||||
|
|
||||||
|
if timing.word:
|
||||||
|
words.append(
|
||||||
|
dict(
|
||||||
|
word=timing.word,
|
||||||
|
start=round(time_offset + timing.start, 2),
|
||||||
|
end=round(time_offset + timing.end, 2),
|
||||||
|
probability=timing.probability,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_tokens += len(timing.tokens)
|
||||||
|
word_index += 1
|
||||||
|
|
||||||
|
# hack: truncate long words at segment boundaries.
|
||||||
|
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||||
|
if len(words) > 0:
|
||||||
|
# ensure the first and second word after a pause is not longer than
|
||||||
|
# twice the median word duration.
|
||||||
|
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
||||||
|
words[0]["end"] - words[0]["start"] > max_duration
|
||||||
|
or (
|
||||||
|
len(words) > 1
|
||||||
|
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
||||||
|
)
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
len(words) > 1
|
||||||
|
and words[1]["end"] - words[1]["start"] > max_duration
|
||||||
|
):
|
||||||
|
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
|
||||||
|
words[0]["end"] = words[1]["start"] = boundary
|
||||||
|
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
||||||
|
|
||||||
|
# prefer the segment-level start timestamp if the first word is too long.
|
||||||
|
if (
|
||||||
|
segment["start"] < words[0]["end"]
|
||||||
|
and segment["start"] - 0.5 > words[0]["start"]
|
||||||
|
):
|
||||||
|
words[0]["start"] = max(
|
||||||
|
0, min(words[0]["end"] - median_duration, segment["start"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
segment["start"] = words[0]["start"]
|
||||||
|
|
||||||
|
# prefer the segment-level end timestamp if the last word is too long.
|
||||||
|
if (
|
||||||
|
segment["end"] > words[-1]["start"]
|
||||||
|
and segment["end"] + 0.5 < words[-1]["end"]
|
||||||
|
):
|
||||||
|
words[-1]["end"] = max(
|
||||||
|
words[-1]["start"] + median_duration, segment["end"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
segment["end"] = words[-1]["end"]
|
||||||
|
|
||||||
|
last_speech_timestamp = segment["end"]
|
||||||
|
|
||||||
|
segment["words"] = words
|
||||||
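A small illustration of `merge_punctuations` with hand-made word timings (a sketch; the import path is an assumption, and the punctuation strings are the defaults used by `add_word_timestamps`):

from whisperlivekit.simul_whisper.whisper.timing import WordTiming, merge_punctuations  # path is an assumption

alignment = [
    WordTiming(word=' "', tokens=[10], start=0.0, end=0.1, probability=0.9),
    WordTiming(word="Hello", tokens=[11], start=0.1, end=0.4, probability=0.95),
    WordTiming(word="!", tokens=[12], start=0.4, end=0.5, probability=0.9),
]
merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")
# the opening quote is folded into the word and the "!" is appended to it, so the middle
# entry should now carry all three token ids while the punctuation-only entries are emptied
print([(t.word, t.tokens) for t in alignment])
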
whisperlivekit/simul_whisper/whisper/tokenizer.py (new file, 395 lines)
@@ -0,0 +1,395 @@
import base64
|
||||||
|
import os
|
||||||
|
import string
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import cached_property, lru_cache
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
LANGUAGES = {
|
||||||
|
"en": "english",
|
||||||
|
"zh": "chinese",
|
||||||
|
"de": "german",
|
||||||
|
"es": "spanish",
|
||||||
|
"ru": "russian",
|
||||||
|
"ko": "korean",
|
||||||
|
"fr": "french",
|
||||||
|
"ja": "japanese",
|
||||||
|
"pt": "portuguese",
|
||||||
|
"tr": "turkish",
|
||||||
|
"pl": "polish",
|
||||||
|
"ca": "catalan",
|
||||||
|
"nl": "dutch",
|
||||||
|
"ar": "arabic",
|
||||||
|
"sv": "swedish",
|
||||||
|
"it": "italian",
|
||||||
|
"id": "indonesian",
|
||||||
|
"hi": "hindi",
|
||||||
|
"fi": "finnish",
|
||||||
|
"vi": "vietnamese",
|
||||||
|
"he": "hebrew",
|
||||||
|
"uk": "ukrainian",
|
||||||
|
"el": "greek",
|
||||||
|
"ms": "malay",
|
||||||
|
"cs": "czech",
|
||||||
|
"ro": "romanian",
|
||||||
|
"da": "danish",
|
||||||
|
"hu": "hungarian",
|
||||||
|
"ta": "tamil",
|
||||||
|
"no": "norwegian",
|
||||||
|
"th": "thai",
|
||||||
|
"ur": "urdu",
|
||||||
|
"hr": "croatian",
|
||||||
|
"bg": "bulgarian",
|
||||||
|
"lt": "lithuanian",
|
||||||
|
"la": "latin",
|
||||||
|
"mi": "maori",
|
||||||
|
"ml": "malayalam",
|
||||||
|
"cy": "welsh",
|
||||||
|
"sk": "slovak",
|
||||||
|
"te": "telugu",
|
||||||
|
"fa": "persian",
|
||||||
|
"lv": "latvian",
|
||||||
|
"bn": "bengali",
|
||||||
|
"sr": "serbian",
|
||||||
|
"az": "azerbaijani",
|
||||||
|
"sl": "slovenian",
|
||||||
|
"kn": "kannada",
|
||||||
|
"et": "estonian",
|
||||||
|
"mk": "macedonian",
|
||||||
|
"br": "breton",
|
||||||
|
"eu": "basque",
|
||||||
|
"is": "icelandic",
|
||||||
|
"hy": "armenian",
|
||||||
|
"ne": "nepali",
|
||||||
|
"mn": "mongolian",
|
||||||
|
"bs": "bosnian",
|
||||||
|
"kk": "kazakh",
|
||||||
|
"sq": "albanian",
|
||||||
|
"sw": "swahili",
|
||||||
|
"gl": "galician",
|
||||||
|
"mr": "marathi",
|
||||||
|
"pa": "punjabi",
|
||||||
|
"si": "sinhala",
|
||||||
|
"km": "khmer",
|
||||||
|
"sn": "shona",
|
||||||
|
"yo": "yoruba",
|
||||||
|
"so": "somali",
|
||||||
|
"af": "afrikaans",
|
||||||
|
"oc": "occitan",
|
||||||
|
"ka": "georgian",
|
||||||
|
"be": "belarusian",
|
||||||
|
"tg": "tajik",
|
||||||
|
"sd": "sindhi",
|
||||||
|
"gu": "gujarati",
|
||||||
|
"am": "amharic",
|
||||||
|
"yi": "yiddish",
|
||||||
|
"lo": "lao",
|
||||||
|
"uz": "uzbek",
|
||||||
|
"fo": "faroese",
|
||||||
|
"ht": "haitian creole",
|
||||||
|
"ps": "pashto",
|
||||||
|
"tk": "turkmen",
|
||||||
|
"nn": "nynorsk",
|
||||||
|
"mt": "maltese",
|
||||||
|
"sa": "sanskrit",
|
||||||
|
"lb": "luxembourgish",
|
||||||
|
"my": "myanmar",
|
||||||
|
"bo": "tibetan",
|
||||||
|
"tl": "tagalog",
|
||||||
|
"mg": "malagasy",
|
||||||
|
"as": "assamese",
|
||||||
|
"tt": "tatar",
|
||||||
|
"haw": "hawaiian",
|
||||||
|
"ln": "lingala",
|
||||||
|
"ha": "hausa",
|
||||||
|
"ba": "bashkir",
|
||||||
|
"jw": "javanese",
|
||||||
|
"su": "sundanese",
|
||||||
|
"yue": "cantonese",
|
||||||
|
}
|
||||||
|
|
||||||
|
# language code lookup by name, with a few language aliases
|
||||||
|
TO_LANGUAGE_CODE = {
|
||||||
|
**{language: code for code, language in LANGUAGES.items()},
|
||||||
|
"burmese": "my",
|
||||||
|
"valencian": "ca",
|
||||||
|
"flemish": "nl",
|
||||||
|
"haitian": "ht",
|
||||||
|
"letzeburgesch": "lb",
|
||||||
|
"pushto": "ps",
|
||||||
|
"panjabi": "pa",
|
||||||
|
"moldavian": "ro",
|
||||||
|
"moldovan": "ro",
|
||||||
|
"sinhalese": "si",
|
||||||
|
"castilian": "es",
|
||||||
|
"mandarin": "zh",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Tokenizer:
|
||||||
|
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||||
|
|
||||||
|
encoding: tiktoken.Encoding
|
||||||
|
num_languages: int
|
||||||
|
language: Optional[str] = None
|
||||||
|
task: Optional[str] = None
|
||||||
|
sot_sequence: Tuple[int] = ()
|
||||||
|
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
for special in self.encoding.special_tokens_set:
|
||||||
|
special_token = self.encoding.encode_single_token(special)
|
||||||
|
self.special_tokens[special] = special_token
|
||||||
|
|
||||||
|
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||||
|
translate: int = self.special_tokens["<|translate|>"]
|
||||||
|
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||||
|
|
||||||
|
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||||
|
sot_sequence = [sot]
|
||||||
|
if self.language is not None:
|
||||||
|
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||||
|
if self.task is not None:
|
||||||
|
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||||
|
sot_sequence.append(task_token)
|
||||||
|
|
||||||
|
self.sot_sequence = tuple(sot_sequence)
|
||||||
|
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return self.encoding.encode(text, **kwargs)
|
||||||
|
|
||||||
|
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||||
|
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||||
|
return self.encoding.decode(token_ids, **kwargs)
|
||||||
|
|
||||||
|
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||||
|
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||||
|
"""
|
||||||
|
return self.encoding.decode(token_ids, **kwargs)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def eot(self) -> int:
|
||||||
|
return self.encoding.eot_token
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def transcribe(self) -> int:
|
||||||
|
return self.special_tokens["<|transcribe|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def translate(self) -> int:
|
||||||
|
return self.special_tokens["<|translate|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sot(self) -> int:
|
||||||
|
return self.special_tokens["<|startoftranscript|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sot_lm(self) -> int:
|
||||||
|
return self.special_tokens["<|startoflm|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sot_prev(self) -> int:
|
||||||
|
return self.special_tokens["<|startofprev|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def no_speech(self) -> int:
|
||||||
|
return self.special_tokens["<|nospeech|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def no_timestamps(self) -> int:
|
||||||
|
return self.special_tokens["<|notimestamps|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def timestamp_begin(self) -> int:
|
||||||
|
return self.special_tokens["<|0.00|>"]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def language_token(self) -> int:
|
||||||
|
"""Returns the token id corresponding to the value of the `language` field"""
|
||||||
|
if self.language is None:
|
||||||
|
raise ValueError("This tokenizer does not have language token configured")
|
||||||
|
|
||||||
|
return self.to_language_token(self.language)
|
||||||
|
|
||||||
|
def to_language_token(self, language):
|
||||||
|
if token := self.special_tokens.get(f"<|{language}|>", None):
|
||||||
|
return token
|
||||||
|
|
||||||
|
raise KeyError(f"Language {language} not found in tokenizer.")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def all_language_tokens(self) -> Tuple[int]:
|
||||||
|
result = []
|
||||||
|
for token, token_id in self.special_tokens.items():
|
||||||
|
if token.strip("<|>") in LANGUAGES:
|
||||||
|
result.append(token_id)
|
||||||
|
return tuple(result)[: self.num_languages]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def all_language_codes(self) -> Tuple[str]:
|
||||||
|
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||||
|
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def non_speech_tokens(self) -> Tuple[int]:
|
||||||
|
"""
|
||||||
|
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||||
|
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||||
|
|
||||||
|
- ♪♪♪
|
||||||
|
- ( SPEAKING FOREIGN LANGUAGE )
|
||||||
|
- [DAVID] Hey there,
|
||||||
|
|
||||||
|
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||||
|
"""
|
||||||
|
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||||
|
symbols += (
|
||||||
|
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||||
|
)
|
||||||
|
|
||||||
|
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||||
|
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||||
|
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||||
|
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||||
|
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||||
|
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||||
|
|
||||||
|
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||||
|
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||||
|
for symbol in symbols + list(miscellaneous):
|
||||||
|
for tokens in [
|
||||||
|
self.encoding.encode(symbol),
|
||||||
|
self.encoding.encode(" " + symbol),
|
||||||
|
]:
|
||||||
|
if len(tokens) == 1 or symbol in miscellaneous:
|
||||||
|
result.add(tokens[0])
|
||||||
|
|
||||||
|
return tuple(sorted(result))
|
||||||
|
|
||||||
|
def split_to_word_tokens(self, tokens: List[int]):
|
||||||
|
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
||||||
|
# These languages don't typically use spaces, so it is difficult to split words
|
||||||
|
# without morpheme analysis. Here, we instead split words at any
|
||||||
|
# position where the tokens are decoded as valid unicode points
|
||||||
|
return self.split_tokens_on_unicode(tokens)
|
||||||
|
|
||||||
|
return self.split_tokens_on_spaces(tokens)
|
||||||
|
|
||||||
|
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||||
|
decoded_full = self.decode_with_timestamps(tokens)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
words = []
|
||||||
|
word_tokens = []
|
||||||
|
current_tokens = []
|
||||||
|
unicode_offset = 0
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
current_tokens.append(token)
|
||||||
|
decoded = self.decode_with_timestamps(current_tokens)
|
||||||
|
|
||||||
|
if (
|
||||||
|
replacement_char not in decoded
|
||||||
|
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||||
|
== replacement_char
|
||||||
|
):
|
||||||
|
words.append(decoded)
|
||||||
|
word_tokens.append(current_tokens)
|
||||||
|
current_tokens = []
|
||||||
|
unicode_offset += len(decoded)
|
||||||
|
|
||||||
|
return words, word_tokens
|
||||||
|
|
||||||
|
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||||
|
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||||
|
words = []
|
||||||
|
word_tokens = []
|
||||||
|
|
||||||
|
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||||
|
special = subword_tokens[0] >= self.eot
|
||||||
|
with_space = subword.startswith(" ")
|
||||||
|
punctuation = subword.strip() in string.punctuation
|
||||||
|
if special or with_space or punctuation or len(words) == 0:
|
||||||
|
words.append(subword)
|
||||||
|
word_tokens.append(subword_tokens)
|
||||||
|
else:
|
||||||
|
words[-1] = words[-1] + subword
|
||||||
|
word_tokens[-1].extend(subword_tokens)
|
||||||
|
|
||||||
|
return words, word_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||||
|
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||||
|
ranks = {
|
||||||
|
base64.b64decode(token): int(rank)
|
||||||
|
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||||
|
}
|
||||||
|
n_vocab = len(ranks)
|
||||||
|
special_tokens = {}
|
||||||
|
|
||||||
|
specials = [
|
||||||
|
"<|endoftext|>",
|
||||||
|
"<|startoftranscript|>",
|
||||||
|
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||||
|
"<|translate|>",
|
||||||
|
"<|transcribe|>",
|
||||||
|
"<|startoflm|>",
|
||||||
|
"<|startofprev|>",
|
||||||
|
"<|nospeech|>",
|
||||||
|
"<|notimestamps|>",
|
||||||
|
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||||
|
]
|
||||||
|
|
||||||
|
for token in specials:
|
||||||
|
special_tokens[token] = n_vocab
|
||||||
|
n_vocab += 1
|
||||||
|
|
||||||
|
return tiktoken.Encoding(
|
||||||
|
name=os.path.basename(vocab_path),
|
||||||
|
explicit_n_vocab=n_vocab,
|
||||||
|
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||||
|
mergeable_ranks=ranks,
|
||||||
|
special_tokens=special_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_tokenizer(
|
||||||
|
multilingual: bool,
|
||||||
|
*,
|
||||||
|
num_languages: int = 99,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||||
|
) -> Tokenizer:
|
||||||
|
if language is not None:
|
||||||
|
language = language.lower()
|
||||||
|
if language not in LANGUAGES:
|
||||||
|
if language in TO_LANGUAGE_CODE:
|
||||||
|
language = TO_LANGUAGE_CODE[language]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported language: {language}")
|
||||||
|
|
||||||
|
if multilingual:
|
||||||
|
encoding_name = "multilingual"
|
||||||
|
language = language or "en"
|
||||||
|
task = task or "transcribe"
|
||||||
|
else:
|
||||||
|
encoding_name = "gpt2"
|
||||||
|
language = None
|
||||||
|
task = None
|
||||||
|
|
||||||
|
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||||
|
|
||||||
|
return Tokenizer(
|
||||||
|
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||||
|
)
|
||||||
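A short sketch of constructing the tokenizer wrapper defined above (the import path is an assumption; `tiktoken` and the bundled `*.tiktoken` vocabulary files under `assets/` are required):

from whisperlivekit.simul_whisper.whisper.tokenizer import get_tokenizer  # path is an assumption

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
tokens = tokenizer.encode(" hello world")
print(tokenizer.decode(tokens))  # " hello world"
print(tokenizer.sot_sequence)    # (<|startoftranscript|>, <|en|>, <|transcribe|>) token ids
print(tokenizer.eot)             # id of <|endoftext|>
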
whisperlivekit/simul_whisper/whisper/transcribe.py (new file, 623 lines)
@@ -0,0 +1,623 @@
import argparse
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from .audio import (
|
||||||
|
FRAMES_PER_SECOND,
|
||||||
|
HOP_LENGTH,
|
||||||
|
N_FRAMES,
|
||||||
|
N_SAMPLES,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
log_mel_spectrogram,
|
||||||
|
pad_or_trim,
|
||||||
|
)
|
||||||
|
from .decoding import DecodingOptions, DecodingResult
|
||||||
|
from .timing import add_word_timestamps
|
||||||
|
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||||
|
from .utils import (
|
||||||
|
exact_div,
|
||||||
|
format_timestamp,
|
||||||
|
get_end,
|
||||||
|
get_writer,
|
||||||
|
make_safe,
|
||||||
|
optional_float,
|
||||||
|
optional_int,
|
||||||
|
str2bool,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .model import Whisper
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe(
|
||||||
|
model: "Whisper",
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor],
|
||||||
|
*,
|
||||||
|
verbose: Optional[bool] = None,
|
||||||
|
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||||
|
compression_ratio_threshold: Optional[float] = 2.4,
|
||||||
|
logprob_threshold: Optional[float] = -1.0,
|
||||||
|
no_speech_threshold: Optional[float] = 0.6,
|
||||||
|
condition_on_previous_text: bool = True,
|
||||||
|
initial_prompt: Optional[str] = None,
|
||||||
|
carry_initial_prompt: bool = False,
|
||||||
|
word_timestamps: bool = False,
|
||||||
|
prepend_punctuations: str = "\"'“¿([{-",
|
||||||
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
clip_timestamps: Union[str, List[float]] = "0",
|
||||||
|
hallucination_silence_threshold: Optional[float] = None,
|
||||||
|
**decode_options,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Transcribe an audio file using Whisper
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model: Whisper
|
||||||
|
The Whisper model instance
|
||||||
|
|
||||||
|
audio: Union[str, np.ndarray, torch.Tensor]
|
||||||
|
The path to the audio file to open, or the audio waveform
|
||||||
|
|
||||||
|
verbose: bool
|
||||||
|
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||||
|
If False, displays minimal details. If None, does not display anything
|
||||||
|
|
||||||
|
temperature: Union[float, Tuple[float, ...]]
|
||||||
|
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||||
|
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||||
|
|
||||||
|
compression_ratio_threshold: float
|
||||||
|
If the gzip compression ratio is above this value, treat as failed
|
||||||
|
|
||||||
|
logprob_threshold: float
|
||||||
|
If the average log probability over sampled tokens is below this value, treat as failed
|
||||||
|
|
||||||
|
no_speech_threshold: float
|
||||||
|
If the no_speech probability is higher than this value AND the average log probability
|
||||||
|
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||||
|
|
||||||
|
condition_on_previous_text: bool
|
||||||
|
if True, the previous output of the model is provided as a prompt for the next window;
|
||||||
|
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||||
|
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||||
|
|
||||||
|
word_timestamps: bool
|
||||||
|
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||||
|
and include the timestamps for each word in each segment.
|
||||||
|
|
||||||
|
prepend_punctuations: str
|
||||||
|
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||||
|
|
||||||
|
append_punctuations: str
|
||||||
|
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||||
|
|
||||||
|
initial_prompt: Optional[str]
|
||||||
|
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||||
|
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||||
|
to make it more likely to predict those words correctly.
|
||||||
|
|
||||||
|
carry_initial_prompt: bool
|
||||||
|
If carry_initial_prompt is True, `initial_prompt` is prepended to the prompt of each internal
|
||||||
|
`decode()` call. If there is not enough context space at the start of the prompt, it is
|
||||||
|
left-sliced to make space.
|
||||||
|
|
||||||
|
decode_options: dict
|
||||||
|
Keyword arguments to construct `DecodingOptions` instances
|
||||||
|
|
||||||
|
clip_timestamps: Union[str, List[float]]
|
||||||
|
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||||
|
The last end timestamp defaults to the end of the file.
|
||||||
|
|
||||||
|
hallucination_silence_threshold: Optional[float]
|
||||||
|
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||||
|
when a possible hallucination is detected
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||||
|
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||||
|
"""
|
||||||
|
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||||
|
if model.device == torch.device("cpu"):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||||
|
if dtype == torch.float16:
|
||||||
|
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
decode_options["fp16"] = False
|
||||||
|
|
||||||
|
# Pad 30-seconds of silence to the input audio, for slicing
|
||||||
|
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||||
|
content_frames = mel.shape[-1] - N_FRAMES
|
||||||
|
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
|
||||||
|
if decode_options.get("language", None) is None:
|
||||||
|
if not model.is_multilingual:
|
||||||
|
decode_options["language"] = "en"
|
||||||
|
else:
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||||
|
)
|
||||||
|
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
_, probs = model.detect_language(mel_segment)
|
||||||
|
decode_options["language"] = max(probs, key=probs.get)
|
||||||
|
if verbose is not None:
|
||||||
|
print(
|
||||||
|
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
language: str = decode_options["language"]
|
||||||
|
task: str = decode_options.get("task", "transcribe")
|
||||||
|
tokenizer = get_tokenizer(
|
||||||
|
model.is_multilingual,
|
||||||
|
num_languages=model.num_languages,
|
||||||
|
language=language,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(clip_timestamps, str):
|
||||||
|
clip_timestamps = [
|
||||||
|
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||||
|
]
|
||||||
|
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||||
|
if len(seek_points) == 0:
|
||||||
|
seek_points.append(0)
|
||||||
|
if len(seek_points) % 2 == 1:
|
||||||
|
seek_points.append(content_frames)
|
||||||
|
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||||
|
|
||||||
|
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||||
|
|
||||||
|
if word_timestamps and task == "translate":
|
||||||
|
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||||
|
|
||||||
|
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||||
|
temperatures = (
|
||||||
|
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||||
|
)
|
||||||
|
decode_result = None
|
||||||
|
|
||||||
|
for t in temperatures:
|
||||||
|
kwargs = {**decode_options}
|
||||||
|
if t > 0:
|
||||||
|
# disable beam_size and patience when t > 0
|
||||||
|
kwargs.pop("beam_size", None)
|
||||||
|
kwargs.pop("patience", None)
|
||||||
|
else:
|
||||||
|
# disable best_of when t == 0
|
||||||
|
kwargs.pop("best_of", None)
|
||||||
|
|
||||||
|
options = DecodingOptions(**kwargs, temperature=t)
|
||||||
|
decode_result = model.decode(segment, options)
|
||||||
|
|
||||||
|
needs_fallback = False
|
||||||
|
if (
|
||||||
|
compression_ratio_threshold is not None
|
||||||
|
and decode_result.compression_ratio > compression_ratio_threshold
|
||||||
|
):
|
||||||
|
needs_fallback = True # too repetitive
|
||||||
|
if (
|
||||||
|
logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
|
):
|
||||||
|
needs_fallback = True # average log probability is too low
|
||||||
|
if (
|
||||||
|
no_speech_threshold is not None
|
||||||
|
and decode_result.no_speech_prob > no_speech_threshold
|
||||||
|
and logprob_threshold is not None
|
||||||
|
and decode_result.avg_logprob < logprob_threshold
|
||||||
|
):
|
||||||
|
needs_fallback = False # silence
|
||||||
|
if not needs_fallback:
|
||||||
|
break
|
||||||
|
|
||||||
|
return decode_result
|
||||||
|
|
||||||
|
clip_idx = 0
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
|
input_stride = exact_div(
|
||||||
|
N_FRAMES, model.dims.n_audio_ctx
|
||||||
|
) # mel frames per output token: 2
|
||||||
|
time_precision = (
|
||||||
|
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||||
|
) # time per output token: 0.02 (seconds)
|
||||||
|
all_tokens = []
|
||||||
|
all_segments = []
|
||||||
|
prompt_reset_since = 0
|
||||||
|
|
||||||
|
remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
|
||||||
|
if initial_prompt is not None:
|
||||||
|
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||||
|
all_tokens.extend(initial_prompt_tokens)
|
||||||
|
remaining_prompt_length -= len(initial_prompt_tokens)
|
||||||
|
else:
|
||||||
|
initial_prompt_tokens = []
|
||||||
|
|
||||||
|
def new_segment(
|
||||||
|
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||||
|
):
|
||||||
|
tokens = tokens.tolist()
|
||||||
|
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||||
|
return {
|
||||||
|
"seek": seek,
|
||||||
|
"start": start,
|
||||||
|
"end": end,
|
||||||
|
"text": tokenizer.decode(text_tokens),
|
||||||
|
"tokens": tokens,
|
||||||
|
"temperature": result.temperature,
|
||||||
|
"avg_logprob": result.avg_logprob,
|
||||||
|
"compression_ratio": result.compression_ratio,
|
||||||
|
"no_speech_prob": result.no_speech_prob,
|
||||||
|
}
|
||||||
|
|
||||||
|
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||||
|
with tqdm.tqdm(
|
||||||
|
total=content_frames, unit="frames", disable=verbose is not False
|
||||||
|
) as pbar:
|
||||||
|
last_speech_timestamp = 0.0
|
||||||
|
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||||
|
# A later commit should turn this into a simpler nested loop.
|
||||||
|
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||||
|
# while seek < seek_clip_end
|
||||||
|
while clip_idx < len(seek_clips):
|
||||||
|
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||||
|
if seek < seek_clip_start:
|
||||||
|
seek = seek_clip_start
|
||||||
|
if seek >= seek_clip_end:
|
||||||
|
clip_idx += 1
|
||||||
|
if clip_idx < len(seek_clips):
|
||||||
|
seek = seek_clips[clip_idx][0]
|
||||||
|
continue
|
||||||
|
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||||
|
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||||
|
mel_segment = mel[:, seek : seek + segment_size]
|
||||||
|
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||||
|
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||||
|
|
||||||
|
if carry_initial_prompt:
|
||||||
|
nignored = max(len(initial_prompt_tokens), prompt_reset_since)
|
||||||
|
remaining_prompt = all_tokens[nignored:][-remaining_prompt_length:]
|
||||||
|
decode_options["prompt"] = initial_prompt_tokens + remaining_prompt
|
||||||
|
else:
|
||||||
|
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||||
|
|
||||||
|
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||||
|
tokens = torch.tensor(result.tokens)
|
||||||
|
|
||||||
|
if no_speech_threshold is not None:
|
||||||
|
# no voice activity check
|
||||||
|
should_skip = result.no_speech_prob > no_speech_threshold
|
||||||
|
if (
|
||||||
|
logprob_threshold is not None
|
||||||
|
and result.avg_logprob > logprob_threshold
|
||||||
|
):
|
||||||
|
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||||
|
should_skip = False
|
||||||
|
|
||||||
|
if should_skip:
|
||||||
|
seek += segment_size # fast-forward to the next segment boundary
|
||||||
|
continue
|
||||||
|
|
||||||
|
previous_seek = seek
|
||||||
|
current_segments = []
|
||||||
|
|
||||||
|
# anomalous words are very long/short/improbable
|
||||||
|
def word_anomaly_score(word: dict) -> float:
|
||||||
|
probability = word.get("probability", 0.0)
|
||||||
|
duration = word["end"] - word["start"]
|
||||||
|
score = 0.0
|
||||||
|
if probability < 0.15:
|
||||||
|
score += 1.0
|
||||||
|
if duration < 0.133:
|
||||||
|
score += (0.133 - duration) * 15
|
||||||
|
if duration > 2.0:
|
||||||
|
score += duration - 2.0
|
||||||
|
return score
|
||||||
|
|
||||||
|
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||||
|
if segment is None or not segment["words"]:
|
||||||
|
return False
|
||||||
|
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||||
|
words = words[:8]
|
||||||
|
score = sum(word_anomaly_score(w) for w in words)
|
||||||
|
return score >= 3 or score + 0.01 >= len(words)
|
||||||
|
|
||||||
|
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||||
|
return next((s for s in segments if s["words"]), None)
|
||||||
|
|
||||||
|
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||||
|
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||||
|
|
||||||
|
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||||
|
consecutive.add_(1)
|
||||||
|
if len(consecutive) > 0:
|
||||||
|
# if the output contains two consecutive timestamp tokens
|
||||||
|
slices = consecutive.tolist()
|
||||||
|
if single_timestamp_ending:
|
||||||
|
slices.append(len(tokens))
|
||||||
|
|
||||||
|
last_slice = 0
|
||||||
|
for current_slice in slices:
|
||||||
|
sliced_tokens = tokens[last_slice:current_slice]
|
||||||
|
start_timestamp_pos = (
|
||||||
|
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
end_timestamp_pos = (
|
||||||
|
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
current_segments.append(
|
||||||
|
new_segment(
|
||||||
|
start=time_offset + start_timestamp_pos * time_precision,
|
||||||
|
end=time_offset + end_timestamp_pos * time_precision,
|
||||||
|
tokens=sliced_tokens,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
last_slice = current_slice
|
||||||
|
|
||||||
|
if single_timestamp_ending:
|
||||||
|
# single timestamp at the end means no speech after the last timestamp.
|
||||||
|
seek += segment_size
|
||||||
|
else:
|
||||||
|
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||||
|
last_timestamp_pos = (
|
||||||
|
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
seek += last_timestamp_pos * input_stride
|
||||||
|
else:
|
||||||
|
duration = segment_duration
|
||||||
|
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||||
|
if (
|
||||||
|
len(timestamps) > 0
|
||||||
|
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||||
|
):
|
||||||
|
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||||
|
last_timestamp_pos = (
|
||||||
|
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||||
|
)
|
||||||
|
duration = last_timestamp_pos * time_precision
|
||||||
|
|
||||||
|
current_segments.append(
|
||||||
|
new_segment(
|
||||||
|
start=time_offset,
|
||||||
|
end=time_offset + duration,
|
||||||
|
tokens=tokens,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seek += segment_size
|
||||||
|
|
||||||
|
if word_timestamps:
|
||||||
|
add_word_timestamps(
|
||||||
|
segments=current_segments,
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
mel=mel_segment,
|
||||||
|
num_frames=segment_size,
|
||||||
|
prepend_punctuations=prepend_punctuations,
|
||||||
|
append_punctuations=append_punctuations,
|
||||||
|
last_speech_timestamp=last_speech_timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not single_timestamp_ending:
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
|
|
||||||
|
# skip silence before possible hallucinations
|
||||||
|
if hallucination_silence_threshold is not None:
|
||||||
|
threshold = hallucination_silence_threshold
|
||||||
|
if not single_timestamp_ending:
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None and last_word_end > time_offset:
|
||||||
|
remaining_duration = window_end_time - last_word_end
|
||||||
|
if remaining_duration > threshold:
|
||||||
|
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||||
|
else:
|
||||||
|
seek = previous_seek + segment_size
|
||||||
|
|
||||||
|
# if first segment might be a hallucination, skip leading silence
|
||||||
|
first_segment = next_words_segment(current_segments)
|
||||||
|
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||||
|
gap = first_segment["start"] - time_offset
|
||||||
|
if gap > threshold:
|
||||||
|
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip silence before any possible hallucination that is surrounded
|
||||||
|
# by silence or more hallucinations
|
||||||
|
hal_last_end = last_speech_timestamp
|
||||||
|
for si in range(len(current_segments)):
|
||||||
|
segment = current_segments[si]
|
||||||
|
if not segment["words"]:
|
||||||
|
continue
|
||||||
|
if is_segment_anomaly(segment):
|
||||||
|
next_segment = next_words_segment(
|
||||||
|
current_segments[si + 1 :]
|
||||||
|
)
|
||||||
|
if next_segment is not None:
|
||||||
|
hal_next_start = next_segment["words"][0]["start"]
|
||||||
|
else:
|
||||||
|
hal_next_start = time_offset + segment_duration
|
||||||
|
silence_before = (
|
||||||
|
segment["start"] - hal_last_end > threshold
|
||||||
|
or segment["start"] < threshold
|
||||||
|
or segment["start"] - time_offset < 2.0
|
||||||
|
)
|
||||||
|
silence_after = (
|
||||||
|
hal_next_start - segment["end"] > threshold
|
||||||
|
or is_segment_anomaly(next_segment)
|
||||||
|
or window_end_time - segment["end"] < 2.0
|
||||||
|
)
|
||||||
|
if silence_before and silence_after:
|
||||||
|
seek = round(
|
||||||
|
max(time_offset + 1, segment["start"])
|
||||||
|
* FRAMES_PER_SECOND
|
||||||
|
)
|
||||||
|
if content_duration - segment["end"] < threshold:
|
||||||
|
seek = content_frames
|
||||||
|
current_segments[si:] = []
|
||||||
|
break
|
||||||
|
hal_last_end = segment["end"]
|
||||||
|
|
||||||
|
last_word_end = get_end(current_segments)
|
||||||
|
if last_word_end is not None:
|
||||||
|
last_speech_timestamp = last_word_end
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
for segment in current_segments:
|
||||||
|
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||||
|
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||||
|
print(make_safe(line))
|
||||||
|
|
||||||
|
# if a segment is instantaneous or does not contain text, clear it
|
||||||
|
for i, segment in enumerate(current_segments):
|
||||||
|
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||||
|
segment["text"] = ""
|
||||||
|
segment["tokens"] = []
|
||||||
|
segment["words"] = []
|
||||||
|
|
||||||
|
all_segments.extend(
|
||||||
|
[
|
||||||
|
{"id": i, **segment}
|
||||||
|
for i, segment in enumerate(
|
||||||
|
current_segments, start=len(all_segments)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
all_tokens.extend(
|
||||||
|
[token for segment in current_segments for token in segment["tokens"]]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not condition_on_previous_text or result.temperature > 0.5:
|
||||||
|
# do not feed the prompt tokens if a high temperature was used
|
||||||
|
prompt_reset_since = len(all_tokens)
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
|
pbar.update(min(content_frames, seek) - previous_seek)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||||
|
segments=all_segments,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cli():
|
||||||
|
from . import available_models
|
||||||
|
|
||||||
|
def valid_model_name(name):
|
||||||
|
if name in available_models() or os.path.exists(name):
|
||||||
|
return name
|
||||||
|
raise ValueError(
|
||||||
|
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||||
|
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
|
||||||
|
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||||
|
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||||
|
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||||
|
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||||
|
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||||
|
|
||||||
|
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||||
|
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||||
|
|
||||||
|
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||||
|
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||||
|
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||||
|
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||||
|
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||||
|
|
||||||
|
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||||
|
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||||
|
parser.add_argument("--carry_initial_prompt", type=str2bool, default=False, help="if True, prepend initial_prompt to every internal decode() call. May reduce the effectiveness of condition_on_previous_text")
|
||||||
|
|
||||||
|
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||||
|
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||||
|
|
||||||
|
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||||
|
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||||
|
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||||
|
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||||
|
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||||
|
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||||
|
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||||
|
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||||
|
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||||
|
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||||
|
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||||
|
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||||
|
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||||
|
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
args = parser.parse_args().__dict__
|
||||||
|
model_name: str = args.pop("model")
|
||||||
|
model_dir: str = args.pop("model_dir")
|
||||||
|
output_dir: str = args.pop("output_dir")
|
||||||
|
output_format: str = args.pop("output_format")
|
||||||
|
device: str = args.pop("device")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||||
|
if args["language"] is not None:
|
||||||
|
warnings.warn(
|
||||||
|
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||||
|
)
|
||||||
|
args["language"] = "en"
|
||||||
|
|
||||||
|
temperature = args.pop("temperature")
|
||||||
|
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||||
|
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||||
|
else:
|
||||||
|
temperature = [temperature]
|
||||||
|
|
||||||
|
if (threads := args.pop("threads")) > 0:
|
||||||
|
torch.set_num_threads(threads)
|
||||||
|
|
||||||
|
from . import load_model
|
||||||
|
|
||||||
|
model = load_model(model_name, device=device, download_root=model_dir)
|
||||||
|
|
||||||
|
writer = get_writer(output_format, output_dir)
|
||||||
|
word_options = [
|
||||||
|
"highlight_words",
|
||||||
|
"max_line_count",
|
||||||
|
"max_line_width",
|
||||||
|
"max_words_per_line",
|
||||||
|
]
|
||||||
|
if not args["word_timestamps"]:
|
||||||
|
for option in word_options:
|
||||||
|
if args[option]:
|
||||||
|
parser.error(f"--{option} requires --word_timestamps True")
|
||||||
|
if args["max_line_count"] and not args["max_line_width"]:
|
||||||
|
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||||
|
if args["max_words_per_line"] and args["max_line_width"]:
|
||||||
|
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||||
|
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||||
|
for audio_path in args.pop("audio"):
|
||||||
|
try:
|
||||||
|
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||||
|
writer(result, audio_path, **writer_args)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
cli()
|
||||||
117
whisperlivekit/simul_whisper/whisper/triton_ops.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError("triton import failed; try `pip install --pre triton`")
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dtw_kernel(
|
||||||
|
cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
|
||||||
|
):
|
||||||
|
offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < M
|
||||||
|
|
||||||
|
for k in range(1, N + M + 1): # k = i + j
|
||||||
|
tl.debug_barrier()
|
||||||
|
|
||||||
|
p0 = cost + (k - 1) * cost_stride
|
||||||
|
p1 = cost + k * cost_stride
|
||||||
|
p2 = cost + k * cost_stride + 1
|
||||||
|
|
||||||
|
c0 = tl.load(p0 + offsets, mask=mask)
|
||||||
|
c1 = tl.load(p1 + offsets, mask=mask)
|
||||||
|
c2 = tl.load(p2 + offsets, mask=mask)
|
||||||
|
|
||||||
|
x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
|
||||||
|
cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
|
||||||
|
|
||||||
|
cost_ptr = cost + (k + 1) * cost_stride + 1
|
||||||
|
tl.store(cost_ptr + offsets, cost_row, mask=mask)
|
||||||
|
|
||||||
|
trace_ptr = trace + (k + 1) * trace_stride + 1
|
||||||
|
tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
|
||||||
|
tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
|
||||||
|
tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def median_kernel(filter_width: int):
|
||||||
|
@triton.jit
|
||||||
|
def kernel(
|
||||||
|
y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
|
||||||
|
): # x.shape[-1] == filter_width
|
||||||
|
row_idx = tl.program_id(0)
|
||||||
|
offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = offsets < y_stride
|
||||||
|
|
||||||
|
x_ptr = x + row_idx * x_stride # noqa: F841
|
||||||
|
y_ptr = y + row_idx * y_stride
|
||||||
|
|
||||||
|
LOAD_ALL_ROWS_HERE # noqa: F821
|
||||||
|
|
||||||
|
BUBBLESORT_HERE # noqa: F821
|
||||||
|
|
||||||
|
tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
|
||||||
|
|
||||||
|
kernel = triton.JITFunction(kernel.fn)
|
||||||
|
new_kernel = kernel.src.replace(
|
||||||
|
" LOAD_ALL_ROWS_HERE",
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
|
||||||
|
for i in range(filter_width)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
new_kernel = new_kernel.replace(
|
||||||
|
" BUBBLESORT_HERE",
|
||||||
|
"\n\n".join(
|
||||||
|
[
|
||||||
|
"\n\n".join(
|
||||||
|
[
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
|
||||||
|
f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
|
||||||
|
f" row{j} = smaller",
|
||||||
|
f" row{j + 1} = larger",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for j in range(filter_width - i - 1)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for i in range(filter_width // 2 + 1)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
new_kernel = new_kernel.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
|
||||||
|
|
||||||
|
if hasattr(kernel, "_unsafe_update_src") is True:
|
||||||
|
kernel._unsafe_update_src(new_kernel)
|
||||||
|
kernel.hash = None
|
||||||
|
else:
|
||||||
|
kernel.src = new_kernel
|
||||||
|
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
|
||||||
|
def median_filter_cuda(x: torch.Tensor, filter_width: int):
|
||||||
|
"""Apply a median filter of given width along the last dimension of x"""
|
||||||
|
slices = x.contiguous().unfold(-1, filter_width, 1)
|
||||||
|
grid = np.prod(slices.shape[:-2])
|
||||||
|
|
||||||
|
kernel = median_kernel(filter_width)
|
||||||
|
y = torch.empty_like(slices[..., 0])
|
||||||
|
|
||||||
|
BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
|
||||||
|
kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
|
||||||
|
|
||||||
|
return y
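A rough usage sketch for median_filter_cuda, assuming a CUDA device and a working triton install; the tensor shape below is only an illustration:

    x = torch.randn(2, 80, 3000, device="cuda")   # hypothetical (batch, mel bins, frames)
    smoothed = median_filter_cuda(x, filter_width=7)
    print(smoothed.shape)  # torch.Size([2, 80, 2994]); the last dim shrinks by filter_width - 1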
|
||||||
318
whisperlivekit/simul_whisper/whisper/utils.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import zlib
|
||||||
|
from typing import Callable, List, Optional, TextIO
|
||||||
|
|
||||||
|
system_encoding = sys.getdefaultencoding()
|
||||||
|
|
||||||
|
if system_encoding != "utf-8":
|
||||||
|
|
||||||
|
def make_safe(string):
|
||||||
|
# replaces any character not representable using the system default encoding with an '?',
|
||||||
|
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||||
|
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def make_safe(string):
|
||||||
|
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def exact_div(x, y):
|
||||||
|
assert x % y == 0
|
||||||
|
return x // y
|
||||||
|
|
||||||
|
|
||||||
|
def str2bool(string):
|
||||||
|
str2val = {"True": True, "False": False}
|
||||||
|
if string in str2val:
|
||||||
|
return str2val[string]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||||
|
|
||||||
|
|
||||||
|
def optional_int(string):
|
||||||
|
return None if string == "None" else int(string)
|
||||||
|
|
||||||
|
|
||||||
|
def optional_float(string):
|
||||||
|
return None if string == "None" else float(string)
|
||||||
|
|
||||||
|
|
||||||
|
def compression_ratio(text) -> float:
|
||||||
|
text_bytes = text.encode("utf-8")
|
||||||
|
return len(text_bytes) / len(zlib.compress(text_bytes))
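For intuition, a small illustration of the ratio this helper computes (the sample strings are made up): repetitive output compresses well and pushes the ratio past the 2.4 default used by --compression_ratio_threshold, while ordinary prose stays near 1.

    repetitive = "the cat sat on the mat " * 50
    normal = "The committee will reconvene on Thursday afternoon."
    print(compression_ratio(repetitive))   # well above 2.4 -> would be treated as a failed decode
    print(compression_ratio(normal))       # close to 1.0 -> accepted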
|
||||||
|
|
||||||
|
|
||||||
|
def format_timestamp(
|
||||||
|
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||||
|
):
|
||||||
|
assert seconds >= 0, "non-negative timestamp expected"
|
||||||
|
milliseconds = round(seconds * 1000.0)
|
||||||
|
|
||||||
|
hours = milliseconds // 3_600_000
|
||||||
|
milliseconds -= hours * 3_600_000
|
||||||
|
|
||||||
|
minutes = milliseconds // 60_000
|
||||||
|
milliseconds -= minutes * 60_000
|
||||||
|
|
||||||
|
seconds = milliseconds // 1_000
|
||||||
|
milliseconds -= seconds * 1_000
|
||||||
|
|
||||||
|
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||||
|
return (
|
||||||
|
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||||
|
)
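A worked example of format_timestamp as defined above (3661.5 s is one hour, one minute and 1.5 seconds):

    print(format_timestamp(3661.5))                           # 01:01:01.500
    print(format_timestamp(3661.5, decimal_marker=","))       # 01:01:01,500 (SRT-style)
    print(format_timestamp(2.25, always_include_hours=True))  # 00:00:02.250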
|
||||||
|
|
||||||
|
|
||||||
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
|
segments[0]["start"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
|
return next(
|
||||||
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
|
segments[-1]["end"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResultWriter:
|
||||||
|
extension: str
|
||||||
|
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
audio_basename = os.path.basename(audio_path)
|
||||||
|
audio_basename = os.path.splitext(audio_basename)[0]
|
||||||
|
output_path = os.path.join(
|
||||||
|
self.output_dir, audio_basename + "." + self.extension
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class WriteTXT(ResultWriter):
|
||||||
|
extension: str = "txt"
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
for segment in result["segments"]:
|
||||||
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SubtitlesWriter(ResultWriter):
|
||||||
|
always_include_hours: bool
|
||||||
|
decimal_marker: str
|
||||||
|
|
||||||
|
def iterate_result(
|
||||||
|
self,
|
||||||
|
result: dict,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
max_line_width: Optional[int] = None,
|
||||||
|
max_line_count: Optional[int] = None,
|
||||||
|
highlight_words: bool = False,
|
||||||
|
max_words_per_line: Optional[int] = None,
|
||||||
|
):
|
||||||
|
options = options or {}
|
||||||
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
|
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||||
|
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||||
|
preserve_segments = max_line_count is None or max_line_width is None
|
||||||
|
max_line_width = max_line_width or 1000
|
||||||
|
max_words_per_line = max_words_per_line or 1000
|
||||||
|
|
||||||
|
def iterate_subtitles():
|
||||||
|
line_len = 0
|
||||||
|
line_count = 1
|
||||||
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
|
subtitle: List[dict] = []
|
||||||
|
last: float = get_start(result["segments"]) or 0.0
|
||||||
|
for segment in result["segments"]:
|
||||||
|
chunk_index = 0
|
||||||
|
words_count = max_words_per_line
|
||||||
|
while chunk_index < len(segment["words"]):
|
||||||
|
remaining_words = len(segment["words"]) - chunk_index
|
||||||
|
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||||
|
words_count = remaining_words
|
||||||
|
for i, original_timing in enumerate(
|
||||||
|
segment["words"][chunk_index : chunk_index + words_count]
|
||||||
|
):
|
||||||
|
timing = original_timing.copy()
|
||||||
|
long_pause = (
|
||||||
|
not preserve_segments and timing["start"] - last > 3.0
|
||||||
|
)
|
||||||
|
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||||
|
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||||
|
if (
|
||||||
|
line_len > 0
|
||||||
|
and has_room
|
||||||
|
and not long_pause
|
||||||
|
and not seg_break
|
||||||
|
):
|
||||||
|
# line continuation
|
||||||
|
line_len += len(timing["word"])
|
||||||
|
else:
|
||||||
|
# new line
|
||||||
|
timing["word"] = timing["word"].strip()
|
||||||
|
if (
|
||||||
|
len(subtitle) > 0
|
||||||
|
and max_line_count is not None
|
||||||
|
and (long_pause or line_count >= max_line_count)
|
||||||
|
or seg_break
|
||||||
|
):
|
||||||
|
# subtitle break
|
||||||
|
yield subtitle
|
||||||
|
subtitle = []
|
||||||
|
line_count = 1
|
||||||
|
elif line_len > 0:
|
||||||
|
# line break
|
||||||
|
line_count += 1
|
||||||
|
timing["word"] = "\n" + timing["word"]
|
||||||
|
line_len = len(timing["word"].strip())
|
||||||
|
subtitle.append(timing)
|
||||||
|
last = timing["start"]
|
||||||
|
chunk_index += max_words_per_line
|
||||||
|
if len(subtitle) > 0:
|
||||||
|
yield subtitle
|
||||||
|
|
||||||
|
if len(result["segments"]) > 0 and "words" in result["segments"][0]:
|
||||||
|
for subtitle in iterate_subtitles():
|
||||||
|
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||||
|
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||||
|
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||||
|
if highlight_words:
|
||||||
|
last = subtitle_start
|
||||||
|
all_words = [timing["word"] for timing in subtitle]
|
||||||
|
for i, this_word in enumerate(subtitle):
|
||||||
|
start = self.format_timestamp(this_word["start"])
|
||||||
|
end = self.format_timestamp(this_word["end"])
|
||||||
|
if last != start:
|
||||||
|
yield last, start, subtitle_text
|
||||||
|
|
||||||
|
yield start, end, "".join(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
|
if j == i
|
||||||
|
else word
|
||||||
|
)
|
||||||
|
for j, word in enumerate(all_words)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
last = end
|
||||||
|
else:
|
||||||
|
yield subtitle_start, subtitle_end, subtitle_text
|
||||||
|
else:
|
||||||
|
for segment in result["segments"]:
|
||||||
|
segment_start = self.format_timestamp(segment["start"])
|
||||||
|
segment_end = self.format_timestamp(segment["end"])
|
||||||
|
segment_text = segment["text"].strip().replace("-->", "->")
|
||||||
|
yield segment_start, segment_end, segment_text
|
||||||
|
|
||||||
|
def format_timestamp(self, seconds: float):
|
||||||
|
return format_timestamp(
|
||||||
|
seconds=seconds,
|
||||||
|
always_include_hours=self.always_include_hours,
|
||||||
|
decimal_marker=self.decimal_marker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteVTT(SubtitlesWriter):
|
||||||
|
extension: str = "vtt"
|
||||||
|
always_include_hours: bool = False
|
||||||
|
decimal_marker: str = "."
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
print("WEBVTT\n", file=file)
|
||||||
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteSRT(SubtitlesWriter):
|
||||||
|
extension: str = "srt"
|
||||||
|
always_include_hours: bool = True
|
||||||
|
decimal_marker: str = ","
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
for i, (start, end, text) in enumerate(
|
||||||
|
self.iterate_result(result, options, **kwargs), start=1
|
||||||
|
):
|
||||||
|
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteTSV(ResultWriter):
|
||||||
|
"""
|
||||||
|
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||||
|
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||||
|
|
||||||
|
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||||
|
an environment setting a language encoding that causes the decimal in a floating point number
|
||||||
|
to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
|
||||||
|
"""
|
||||||
|
|
||||||
|
extension: str = "tsv"
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
print("start", "end", "text", sep="\t", file=file)
|
||||||
|
for segment in result["segments"]:
|
||||||
|
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||||
|
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||||
|
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteJSON(ResultWriter):
|
||||||
|
extension: str = "json"
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
json.dump(result, file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_writer(
|
||||||
|
output_format: str, output_dir: str
|
||||||
|
) -> Callable[[dict, TextIO, dict], None]:
|
||||||
|
writers = {
|
||||||
|
"txt": WriteTXT,
|
||||||
|
"vtt": WriteVTT,
|
||||||
|
"srt": WriteSRT,
|
||||||
|
"tsv": WriteTSV,
|
||||||
|
"json": WriteJSON,
|
||||||
|
}
|
||||||
|
|
||||||
|
if output_format == "all":
|
||||||
|
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||||
|
|
||||||
|
def write_all(
|
||||||
|
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
for writer in all_writers:
|
||||||
|
writer(result, file, options, **kwargs)
|
||||||
|
|
||||||
|
return write_all
|
||||||
|
|
||||||
|
return writers[output_format](output_dir)
|
||||||
1
whisperlivekit/simul_whisper/whisper/version.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = "20250625"
|
||||||
@@ -1,20 +1,57 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, Any, List
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||||
|
|
||||||
|
def format_time(seconds: float) -> str:
|
||||||
|
"""Format seconds as HH:MM:SS."""
|
||||||
|
return str(timedelta(seconds=int(seconds)))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TimedText:
|
class TimedText:
|
||||||
start: Optional[float]
|
start: Optional[float] = 0
|
||||||
end: Optional[float]
|
end: Optional[float] = 0
|
||||||
text: Optional[str] = ''
|
text: Optional[str] = ''
|
||||||
speaker: Optional[int] = -1
|
speaker: Optional[int] = -1
|
||||||
probability: Optional[float] = None
|
probability: Optional[float] = None
|
||||||
is_dummy: Optional[bool] = False
|
is_dummy: Optional[bool] = False
|
||||||
|
detected_language: Optional[str] = None
|
||||||
|
|
||||||
|
def is_punctuation(self):
|
||||||
|
return self.text.strip() in PUNCTUATION_MARKS
|
||||||
|
|
||||||
|
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||||
|
return not (self.end <= other.start or other.end <= self.start)
|
||||||
|
|
||||||
|
def is_within(self, other: 'TimedText') -> bool:
|
||||||
|
return other.contains_timespan(self)
|
||||||
|
|
||||||
@dataclass
|
def duration(self) -> float:
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
def contains_time(self, time: float) -> bool:
|
||||||
|
return self.start <= time <= self.end
|
||||||
|
|
||||||
|
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||||
|
return self.start <= other.start and self.end >= other.end
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.text)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
class ASRToken(TimedText):
|
class ASRToken(TimedText):
|
||||||
|
|
||||||
|
corrected_speaker: Optional[int] = -1
|
||||||
|
validated_speaker: bool = False
|
||||||
|
validated_text: bool = False
|
||||||
|
validated_language: bool = False
|
||||||
|
|
||||||
def with_offset(self, offset: float) -> "ASRToken":
|
def with_offset(self, offset: float) -> "ASRToken":
|
||||||
"""Return a new token with the time offset added."""
|
"""Return a new token with the time offset added."""
|
||||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
|
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability, detected_language=self.detected_language)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Sentence(TimedText):
|
class Sentence(TimedText):
|
||||||
@@ -22,11 +59,128 @@ class Sentence(TimedText):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Transcript(TimedText):
|
class Transcript(TimedText):
|
||||||
pass
|
"""
|
||||||
|
Represents a concatenation of several ASRToken objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tokens(
|
||||||
|
cls,
|
||||||
|
tokens: List[ASRToken],
|
||||||
|
sep: Optional[str] = None,
|
||||||
|
offset: float = 0
|
||||||
|
) -> "Transcript":
|
||||||
|
sep = sep if sep is not None else ' '
|
||||||
|
text = sep.join(token.text for token in tokens)
|
||||||
|
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||||
|
if tokens:
|
||||||
|
start = offset + tokens[0].start
|
||||||
|
end = offset + tokens[-1].end
|
||||||
|
else:
|
||||||
|
start = None
|
||||||
|
end = None
|
||||||
|
return cls(start, end, text, probability=probability)
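A minimal usage sketch of from_tokens, assuming the ASRToken class above; the token values are hypothetical:

    tokens = [
        ASRToken(start=0.0, end=0.4, text="Hello", probability=0.9),
        ASRToken(start=0.5, end=0.9, text="world", probability=0.8),
    ]
    t = Transcript.from_tokens(tokens, sep=" ", offset=10.0)
    print(t.start, t.end, t.text)   # 10.0 10.9 Hello world
    print(t.probability)            # 0.85, the mean of the token probabilities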
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpeakerSegment(TimedText):
|
class SpeakerSegment(TimedText):
|
||||||
"""Represents a segment of audio attributed to a specific speaker.
|
"""Represents a segment of audio attributed to a specific speaker.
|
||||||
No text nor probability is associated with this segment.
|
No text nor probability is associated with this segment.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Translation(TimedText):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def approximate_cut_at(self, cut_time):
|
||||||
|
"""
|
||||||
|
Each word in the text is assumed to have duration (end - start) / number of words.
|
||||||
|
"""
|
||||||
|
if not self.text or not self.contains_time(cut_time):
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
words = self.text.split()
|
||||||
|
num_words = len(words)
|
||||||
|
if num_words == 0:
|
||||||
|
return self, None
|
||||||
|
|
||||||
|
duration_per_word = self.duration() / num_words
|
||||||
|
|
||||||
|
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||||
|
|
||||||
|
if cut_word_index >= num_words:
|
||||||
|
cut_word_index = num_words - 1
|
||||||
|
|
||||||
|
text0 = " ".join(words[:cut_word_index])
|
||||||
|
text1 = " ".join(words[cut_word_index:])
|
||||||
|
|
||||||
|
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||||
|
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||||
|
|
||||||
|
return segment0, segment1
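A worked example of approximate_cut_at with hypothetical values, assuming the Translation class above: four words over four seconds gives one second per word, so a cut at t = 2.0 splits after the second word.

    seg = Translation(start=0.0, end=4.0, text="one two three four")
    first, second = seg.approximate_cut_at(2.0)
    print(first.start, first.end, first.text)     # 0.0 2.0 one two
    print(second.start, second.end, second.text)  # 2.0 4.0 three four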
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Silence():
|
||||||
|
duration: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Line(TimedText):
|
||||||
|
translation: str = ''
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
_dict = {
|
||||||
|
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||||
|
'text': self.text,
|
||||||
|
'start': format_time(self.start),
|
||||||
|
'end': format_time(self.end),
|
||||||
|
}
|
||||||
|
if self.translation:
|
||||||
|
_dict['translation'] = self.translation
|
||||||
|
if self.detected_language:
|
||||||
|
_dict['detected_language'] = self.detected_language
|
||||||
|
return _dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrontData():
|
||||||
|
status: str = ''
|
||||||
|
error: str = ''
|
||||||
|
lines: list[Line] = field(default_factory=list)
|
||||||
|
buffer_transcription: str = ''
|
||||||
|
buffer_diarization: str = ''
|
||||||
|
remaining_time_transcription: float = 0.
|
||||||
|
remaining_time_diarization: float = 0.
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
_dict = {
|
||||||
|
'status': self.status,
|
||||||
|
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||||
|
'buffer_transcription': self.buffer_transcription,
|
||||||
|
'buffer_diarization': self.buffer_diarization,
|
||||||
|
'remaining_time_transcription': self.remaining_time_transcription,
|
||||||
|
'remaining_time_diarization': self.remaining_time_diarization,
|
||||||
|
}
|
||||||
|
if self.error:
|
||||||
|
_dict['error'] = self.error
|
||||||
|
return _dict
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChangeSpeaker:
|
||||||
|
speaker: int
|
||||||
|
start: int
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class State():
|
||||||
|
tokens: list = field(default_factory=list)
|
||||||
|
last_validated_token: int = 0
|
||||||
|
translation_validated_segments: list = field(default_factory=list)
|
||||||
|
translation_buffer: list = field(default_factory=list)
|
||||||
|
buffer_transcription: str = field(default_factory=Transcript)
|
||||||
|
end_buffer: float = 0.0
|
||||||
|
end_attributed_speaker: float = 0.0
|
||||||
|
remaining_time_transcription: float = 0.0
|
||||||
|
remaining_time_diarization: float = 0.0
|
||||||
|
beg_loop: Optional[int] = None
|
||||||
60
whisperlivekit/trail_repetition.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Sequence, Callable, Any, Optional, Dict
|
||||||
|
|
||||||
|
def _detect_tail_repetition(
|
||||||
|
seq: Sequence[Any],
|
||||||
|
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||||
|
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||||
|
max_tail: int = 300, # search window from the end for speed
|
||||||
|
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
vals = [key(x) for x in seq][-max_tail:]
|
||||||
|
n = len(vals)
|
||||||
|
best = None
|
||||||
|
|
||||||
|
# try every possible block length
|
||||||
|
for b in range(min_block, n // 2 + 1):
|
||||||
|
block = vals[-b:]
|
||||||
|
# count how many times this block repeats contiguously at the very end
|
||||||
|
count, i = 0, n
|
||||||
|
while i - b >= 0 and vals[i - b:i] == block:
|
||||||
|
count += 1
|
||||||
|
i -= b
|
||||||
|
|
||||||
|
if count >= 2:
|
||||||
|
cand = {
|
||||||
|
"block_size": b,
|
||||||
|
"count": count,
|
||||||
|
"start_index": len(seq) - count * b, # in original seq
|
||||||
|
"end_index": len(seq),
|
||||||
|
}
|
||||||
|
if (best is None or
|
||||||
|
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||||
|
(prefer == "smallest" and b < best["block_size"])):
|
||||||
|
best = cand
|
||||||
|
return best
|
||||||
|
|
||||||
|
def trim_tail_repetition(
|
||||||
|
seq: Sequence[Any],
|
||||||
|
key: Callable[[Any], Any] = lambda x: x,
|
||||||
|
min_block: int = 1,
|
||||||
|
max_tail: int = 300,
|
||||||
|
prefer: str = "longest",
|
||||||
|
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns a new sequence with repeated tail trimmed.
|
||||||
|
keep=1 -> keep a single copy of the repeated block.
|
||||||
|
keep=0 -> remove all copies of the repeated block.
|
||||||
|
"""
|
||||||
|
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||||
|
if not rep:
|
||||||
|
return seq, False # nothing to trim
|
||||||
|
|
||||||
|
b, c = rep["block_size"], rep["count"]
|
||||||
|
if keep < 0:
|
||||||
|
keep = 0
|
||||||
|
if keep >= c:
|
||||||
|
return seq, False # nothing to trim (already <= keep copies)
|
||||||
|
# new length = total - (copies_to_remove * block_size)
|
||||||
|
new_len = len(seq) - (c - keep) * b
|
||||||
|
return seq[:new_len], True
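A small usage sketch with made-up tokens: the trailing block "is simple ." repeats three times, and the default keep=1 leaves a single copy of it.

    tokens = ["so", "the", "plan", "is", "simple", ".",
              "is", "simple", ".", "is", "simple", "."]
    trimmed, changed = trim_tail_repetition(tokens)
    print(changed)   # True
    print(trimmed)   # ['so', 'the', 'plan', 'is', 'simple', '.']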
|
||||||
0
whisperlivekit/vad_models/__init__.py
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
51
whisperlivekit/warmup.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def load_file(warmup_file=None, timeout=5):
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import urllib.request
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
if warmup_file == "":
|
||||||
|
logger.info(f"Skipping warmup.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Download JFK sample if not already present
|
||||||
|
if warmup_file is None:
|
||||||
|
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||||
|
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
|
try:
|
||||||
|
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||||
|
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
|
||||||
|
f.write(r.read())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Warmup file download failed: {e}.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate file and load
|
||||||
|
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||||
|
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio, _ = librosa.load(warmup_file, sr=16000)
|
||||||
|
return audio
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load warmup file: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||||
|
"""
|
||||||
|
Warmup the ASR model by transcribing a short audio file.
|
||||||
|
"""
|
||||||
|
audio = load_file(warmup_file=warmup_file, timeout=timeout)
|
||||||
|
if audio is None:
|
||||||
|
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
|
||||||
|
return
|
||||||
|
asr.transcribe(audio)
|
||||||
|
logger.info("ASR model is warmed up.")
|
||||||
625
whisperlivekit/web/live_transcription.css
Normal file
@@ -0,0 +1,625 @@
|
|||||||
|
:root {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
:root:not([data-theme="light"]) {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="dark"] {
|
||||||
|
--bg: #0b0b0b;
|
||||||
|
--text: #e6e6e6;
|
||||||
|
--muted: #9aa0a6;
|
||||||
|
--border: #333333;
|
||||||
|
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||||
|
--chip-text: #e6e6e6;
|
||||||
|
--spinner-border: #555555;
|
||||||
|
--spinner-top: #dddddd;
|
||||||
|
--silence-bg: #1a1a1a;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||||
|
--button-bg: #111111;
|
||||||
|
--button-border: #333333;
|
||||||
|
--wave-stroke: #e6e6e6;
|
||||||
|
--label-dia-text: #b3b3b3;
|
||||||
|
--label-trans-text: #ffffff;
|
||||||
|
}
|
||||||
|
|
||||||
|
:root[data-theme="light"] {
|
||||||
|
--bg: #ffffff;
|
||||||
|
--text: #111111;
|
||||||
|
--muted: #666666;
|
||||||
|
--border: #e5e5e5;
|
||||||
|
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||||
|
--chip-text: #000000;
|
||||||
|
--spinner-border: #8d8d8d5c;
|
||||||
|
--spinner-top: #b0b0b0;
|
||||||
|
--silence-bg: #f3f3f3;
|
||||||
|
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||||
|
--button-bg: #ffffff;
|
||||||
|
--button-border: #e9e9e9;
|
||||||
|
--wave-stroke: #000000;
|
||||||
|
--label-dia-text: #868686;
|
||||||
|
--label-trans-text: #111111;
|
||||||
|
}
|
||||||
|
|
||||||
|
html.is-extension
|
||||||
|
{
|
||||||
|
width: 350px;
|
||||||
|
height: 500px;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||||
|
margin: 0;
|
||||||
|
text-align: center;
|
||||||
|
background-color: var(--bg);
|
||||||
|
color: var(--text);
|
||||||
|
height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Record button */
|
||||||
|
#recordButton {
|
||||||
|
width: 50px;
|
||||||
|
height: 50px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
border: 1px solid var(--button-border);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording {
|
||||||
|
width: 180px;
|
||||||
|
border-radius: 40px;
|
||||||
|
justify-content: flex-start;
|
||||||
|
padding-left: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton:active {
|
||||||
|
transform: scale(0.95);
|
||||||
|
}
|
||||||
|
|
||||||
|
.shape-container {
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.shape {
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
background-color: rgb(209, 61, 53);
|
||||||
|
border-radius: 50%;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton:disabled .shape {
|
||||||
|
background-color: #6e6d6d;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording .shape {
|
||||||
|
border-radius: 5px;
|
||||||
|
width: 25px;
|
||||||
|
height: 25px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Recording elements */
|
||||||
|
.recording-info {
|
||||||
|
display: none;
|
||||||
|
align-items: center;
|
||||||
|
margin-left: 15px;
|
||||||
|
flex-grow: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#recordButton.recording .recording-info {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
|
||||||
|
.wave-container {
|
||||||
|
width: 60px;
|
||||||
|
height: 30px;
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
#waveCanvas {
|
||||||
|
width: 100%;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.timer {
|
||||||
|
font-size: 14px;
|
||||||
|
font-weight: 500;
|
||||||
|
color: var(--text);
|
||||||
|
margin-left: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#status {
|
||||||
|
margin-top: 15px;
|
||||||
|
font-size: 16px;
|
||||||
|
color: var(--text);
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header-container {
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
background-color: var(--bg);
|
||||||
|
z-index: 100;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Settings */
|
||||||
|
.settings-container {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
gap: 15px;
|
||||||
|
position: relative;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.buttons-container {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 15px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings {
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-toggle {
|
||||||
|
width: 40px;
|
||||||
|
height: 40px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 50%;
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
border: 1px solid var(--button-border);
|
||||||
|
cursor: pointer;
|
||||||
|
display: none;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
transition: all 0.2s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-toggle:hover {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-toggle.active {
|
||||||
|
background-color: var(--chip-bg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings-toggle img {
|
||||||
|
width: 20px;
|
||||||
|
height: 20px;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 10000px) {
|
||||||
|
.settings-toggle {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings {
|
||||||
|
display: none;
|
||||||
|
background: var(--bg);
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
border-radius: 18px;
|
||||||
|
padding: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.settings.visible {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (max-width: 600px) {
|
||||||
|
.settings-container {
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: center;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.buttons-container {
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
align-items: center;
|
||||||
|
gap: 15px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
.field {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
align-items: flex-start;
|
||||||
|
gap: 3px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chunkSelector,
|
||||||
|
#websocketInput,
|
||||||
|
#themeSelector,
|
||||||
|
#microphoneSelect {
|
||||||
|
font-size: 16px;
|
||||||
|
padding: 5px 8px;
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
color: var(--text);
|
||||||
|
max-height: 30px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#microphoneSelect {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 190px;
|
||||||
|
min-width: 120px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chunkSelector:focus,
|
||||||
|
#websocketInput:focus,
|
||||||
|
#themeSelector:focus,
|
||||||
|
#microphoneSelect:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #007bff;
|
||||||
|
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
label {
|
||||||
|
font-size: 13px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
.ws-default {
|
||||||
|
font-size: 12px;
|
||||||
|
color: var(--muted);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Segmented pill control for Theme */
|
||||||
|
.segmented {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: stretch;
|
||||||
|
border: 1px solid var(--button-border);
|
||||||
|
background-color: var(--button-bg);
|
||||||
|
border-radius: 999px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented input[type="radio"] {
|
||||||
|
position: absolute;
|
||||||
|
opacity: 0;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.theme-selector-container {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
margin-top: 17px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.segmented label {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 6px;
|
||||||
|
padding: 6px 12px;
|
||||||
|
font-size: 14px;
|
||||||
|
color: var(--muted);
|
||||||
|
cursor: pointer;
|
||||||
|
user-select: none;
|
||||||
|
  transition: background-color 0.2s ease, color 0.2s ease;
}

.segmented label span {
  display: none;
}

.segmented label:hover span {
  display: inline;
}

.segmented label:hover {
  background-color: var(--chip-bg);
}

.segmented img {
  width: 16px;
  height: 16px;
}

.segmented input[type="radio"]:checked + label {
  background-color: var(--chip-bg);
  color: var(--text);
}

.segmented input[type="radio"]:focus-visible + label,
.segmented input[type="radio"]:focus + label {
  outline: 2px solid #007bff;
  outline-offset: 2px;
  border-radius: 999px;
}

.transcript-container {
  flex: 1;
  overflow-y: auto;
  padding: 20px;
  scrollbar-width: none;
  -ms-overflow-style: none;
}

.transcript-container::-webkit-scrollbar {
  display: none;
}

/* Transcript area */
#linesTranscript {
  margin: 0 auto;
  max-width: 700px;
  text-align: left;
  font-size: 16px;
}

#linesTranscript p {
  margin: 0px 0;
}

#linesTranscript strong {
  color: var(--text);
}

#speaker {
  border: 1px solid var(--border);
  border-radius: 100px;
  padding: 2px 10px;
  font-size: 14px;
  margin-bottom: 0px;
}

.label_diarization {
  background-color: var(--chip-bg);
  border-radius: 100px;
  padding: 2px 10px;
  margin-left: 10px;
  display: inline-block;
  white-space: nowrap;
  font-size: 14px;
  margin-bottom: 0px;
  color: var(--label-dia-text);
}

.label_transcription {
  background-color: var(--chip-bg);
  border-radius: 100px;
  padding: 2px 10px;
  display: inline-block;
  white-space: nowrap;
  margin-left: 10px;
  font-size: 14px;
  margin-bottom: 0px;
  color: var(--label-trans-text);
}

.label_translation {
  background-color: var(--chip-bg);
  display: inline-flex;
  border-radius: 10px;
  padding: 4px 8px;
  margin-top: 4px;
  font-size: 14px;
  color: var(--text);
  align-items: flex-start;
  gap: 4px;
}

.lag-diarization-value {
  margin-left: 10px;
}

.label_translation img {
  width: 12px;
  height: 12px;
  margin-top: 2px;
}

#timeInfo {
  color: var(--muted);
  margin-left: 0px;
}

.textcontent {
  font-size: 16px;
  padding-left: 10px;
  margin-bottom: 10px;
  margin-top: 1px;
  padding-top: 5px;
  border-radius: 0px 0px 0px 10px;
}

.buffer_diarization {
  color: var(--label-dia-text);
}

.buffer_transcription {
  color: #7474748c;
  margin-left: 4px;
}

.spinner {
  display: inline-block;
  width: 8px;
  height: 8px;
  border: 2px solid var(--spinner-border);
  border-top: 2px solid var(--spinner-top);
  border-radius: 50%;
  animation: spin 0.7s linear infinite;
  vertical-align: middle;
  margin-bottom: 2px;
  margin-right: 5px;
}

@keyframes spin {
  to {
    transform: rotate(360deg);
  }
}

.silence {
  color: var(--muted);
  background-color: var(--silence-bg);
  font-size: 13px;
  border-radius: 30px;
  padding: 2px 10px;
}

.loading {
  color: var(--muted);
  background-color: var(--loading-bg);
  border-radius: 8px 8px 8px 0px;
  padding: 2px 10px;
  font-size: 14px;
  margin-bottom: 0px;
}

/* for smaller screens */
@media (max-width: 200px) {
  .header-container {
    padding: 15px;
  }

  .settings-container {
    flex-direction: column;
    gap: 10px;
  }

  .buttons-container {
    gap: 10px;
  }

  .settings {
    justify-content: center;
    gap: 8px;
  }

  .field {
    align-items: center;
  }

  #websocketInput,
  #microphoneSelect {
    min-width: 100px;
    max-width: 160px;
  }

  .theme-selector-container {
    margin-top: 10px;
  }

  .transcript-container {
    padding: 15px;
  }
}

@media (max-width: 480px) {
  .header-container {
    padding: 10px;
  }

  .settings {
    flex-direction: column;
    align-items: center;
    gap: 6px;
  }

  #websocketInput,
  #microphoneSelect {
    max-width: 140px;
  }

  .segmented label {
    padding: 4px 8px;
    font-size: 12px;
  }

  .segmented img {
    width: 14px;
    height: 14px;
  }

  .transcript-container {
    padding: 10px;
  }
}

.label_language {
  background-color: var(--chip-bg);
  margin-bottom: 0px;
  border-radius: 100px;
  padding: 2px 8px;
  margin-left: 10px;
  display: inline-flex;
  align-items: center;
  gap: 4px;
  font-size: 14px;
  color: var(--muted);
}

.speaker-badge {
  display: inline-flex;
  align-items: center;
  justify-content: center;
  width: 16px;
  height: 16px;
  margin-left: -5px;
  border-radius: 50%;
  font-size: 11px;
  line-height: 1;
  font-weight: 800;
  color: var(--muted);
}
@@ -4,679 +4,76 @@
 <head>
   <meta charset="UTF-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
-  <title>Audio Transcription</title>
+  <title>WhisperLiveKit</title>
-  <style>
+  <link rel="stylesheet" href="live_transcription.css" />
-    body {
-      font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
-      margin: 20px;
-      text-align: center;
-    }
-
-    #recordButton {
-      width: 50px;
-      height: 50px;
-      border: none;
-      border-radius: 50%;
-      background-color: white;
-      cursor: pointer;
-      transition: all 0.3s ease;
-      border: 1px solid rgb(233, 233, 233);
-      display: flex;
-      align-items: center;
-      justify-content: center;
-      position: relative;
-    }
-
-    #recordButton.recording {
-      width: 180px;
-      border-radius: 40px;
-      justify-content: flex-start;
-      padding-left: 20px;
-    }
-
-    #recordButton:active {
-      transform: scale(0.95);
-    }
-
-    .shape-container {
-      width: 25px;
-      height: 25px;
-      display: flex;
-      align-items: center;
-      justify-content: center;
-      flex-shrink: 0;
-    }
-
-    .shape {
-      width: 25px;
-      height: 25px;
-      background-color: rgb(209, 61, 53);
-      border-radius: 50%;
-      transition: all 0.3s ease;
-    }
-
-    #recordButton:disabled .shape {
-      background-color: #6e6d6d;
-    }
-
-    #recordButton.recording .shape {
-      border-radius: 5px;
-      width: 25px;
-      height: 25px;
-    }
-
-    /* Recording elements */
-    .recording-info {
-      display: none;
-      align-items: center;
-      margin-left: 15px;
-      flex-grow: 1;
-    }
-
-    #recordButton.recording .recording-info {
-      display: flex;
-    }
-
-    .wave-container {
-      width: 60px;
-      height: 30px;
-      position: relative;
-      display: flex;
-      align-items: center;
-      justify-content: center;
-    }
-
-    #waveCanvas {
-      width: 100%;
-      height: 100%;
-    }
-
-    .timer {
-      font-size: 14px;
-      font-weight: 500;
-      color: #333;
-      margin-left: 10px;
-    }
-
-    #status {
-      margin-top: 20px;
-      font-size: 16px;
-      color: #333;
-    }
-
-    .settings-container {
-      display: flex;
-      justify-content: center;
-      align-items: center;
-      gap: 15px;
-      margin-top: 20px;
-    }
-
-    .settings {
-      display: flex;
-      flex-direction: column;
-      align-items: flex-start;
-      gap: 5px;
-    }
-
-    #chunkSelector,
-    #websocketInput {
-      font-size: 16px;
-      padding: 5px;
-      border-radius: 5px;
-      border: 1px solid #ddd;
-      background-color: #ffffff;
-      max-height: 30px;
-    }
-
-    #websocketInput {
-      width: 200px;
-    }
-
-    #chunkSelector:focus,
-    #websocketInput:focus {
-      outline: none;
-      border-color: #007bff;
-    }
-
-    label {
-      font-size: 14px;
-    }
-
-    /* Speaker-labeled transcript area */
-    #linesTranscript {
-      margin: 20px auto;
-      max-width: 700px;
-      text-align: left;
-      font-size: 16px;
-    }
-
-    #linesTranscript p {
-      margin: 0px 0;
-    }
-
-    #linesTranscript strong {
-      color: #333;
-    }
-
-    #speaker {
-      border: 1px solid rgb(229, 229, 229);
-      border-radius: 100px;
-      padding: 2px 10px;
-      font-size: 14px;
-      margin-bottom: 0px;
-    }
-
-    .label_diarization {
-      background-color: #ffffff66;
-      border-radius: 8px 8px 8px 8px;
-      padding: 2px 10px;
-      margin-left: 10px;
-      display: inline-block;
-      white-space: nowrap;
-      font-size: 14px;
-      margin-bottom: 0px;
-      color: rgb(134, 134, 134)
-    }
-
-    .label_transcription {
-      background-color: #ffffff66;
-      border-radius: 8px 8px 8px 8px;
-      padding: 2px 10px;
-      display: inline-block;
-      white-space: nowrap;
-      margin-left: 10px;
-      font-size: 14px;
-      margin-bottom: 0px;
-      color: #000000
-    }
-
-    #timeInfo {
-      color: #666;
-      margin-left: 10px;
-    }
-
-    .textcontent {
-      font-size: 16px;
-      /* margin-left: 10px; */
-      padding-left: 10px;
-      margin-bottom: 10px;
-      margin-top: 1px;
-      padding-top: 5px;
-      border-radius: 0px 0px 0px 10px;
-    }
-
-    .buffer_diarization {
-      color: rgb(134, 134, 134);
-      margin-left: 4px;
-    }
-
-    .buffer_transcription {
-      color: #7474748c;
-      margin-left: 4px;
-    }
-
-    .spinner {
-      display: inline-block;
-      width: 8px;
-      height: 8px;
-      border: 2px solid #8d8d8d5c;
-      border-top: 2px solid #6c6c6ce5;
-      border-radius: 50%;
-      animation: spin 0.6s linear infinite;
-      vertical-align: middle;
-      margin-bottom: 2px;
-      margin-right: 5px;
-    }
-
-    @keyframes spin {
-      to {
-        transform: rotate(360deg);
-      }
-    }
-
-    .silence {
-      color: #666;
-      background-color: #f3f3f3;
-      font-size: 13px;
-      border-radius: 30px;
-      padding: 2px 10px;
-    }
-
-    .loading {
-      color: #666;
-      background-color: #ff4d4d0f;
-      border-radius: 8px 8px 8px 0px;
-      padding: 2px 10px;
-      font-size: 14px;
-      margin-bottom: 0px;
-    }
-  </style>
 </head>

 <body>
-  <div class="settings-container">
-    <button id="recordButton">
-      <div class="shape-container">
-        <div class="shape"></div>
-      </div>
-      <div class="recording-info">
-        <div class="wave-container">
-          <canvas id="waveCanvas"></canvas>
-        </div>
-        <div class="timer">00:00</div>
-      </div>
-    </button>
-    <div class="settings">
-      <div>
-        <label for="chunkSelector">Chunk size (ms):</label>
-        <select id="chunkSelector">
-          <option value="500">500 ms</option>
-          <option value="1000" selected>1000 ms</option>
-          <option value="2000">2000 ms</option>
-          <option value="3000">3000 ms</option>
-          <option value="4000">4000 ms</option>
-          <option value="5000">5000 ms</option>
-        </select>
-      </div>
-      <div>
-        <label for="websocketInput">WebSocket URL:</label>
-        <input id="websocketInput" type="text" />
-      </div>
-    </div>
-  </div>
-  <p id="status"></p>
-  <!-- Speaker-labeled transcript -->
-  <div id="linesTranscript"></div>
+  <div class="header-container">
+    <div class="settings-container">
+      <div class="buttons-container">
+        <button id="recordButton">
+          <div class="shape-container">
+            <div class="shape"></div>
+          </div>
+          <div class="recording-info">
+            <div class="wave-container">
+              <canvas id="waveCanvas"></canvas>
+            </div>
+            <div class="timer">00:00</div>
+          </div>
+        </button>
+
+        <button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
+          <img src="web/src/settings.svg" alt="Settings" />
+        </button>
+      </div>
+
+      <div class="settings">
+        <div class="field">
+          <label for="websocketInput">Websocket URL</label>
+          <input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
+        </div>
+
+        <div class="field">
+          <label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
+          <select id="microphoneSelect">
+            <option value="">Default Microphone</option>
+          </select>
+        </div>
+
+        <div class="theme-selector-container">
+          <div class="segmented" role="radiogroup" aria-label="Theme selector">
+            <input type="radio" id="theme-system" name="theme" value="system" />
+            <label for="theme-system" title="System">
+              <img src="/web/src/system_mode.svg" alt="" />
+              <span>System</span>
+            </label>
+
+            <input type="radio" id="theme-light" name="theme" value="light" />
+            <label for="theme-light" title="Light">
+              <img src="/web/src/light_mode.svg" alt="" />
+              <span>Light</span>
+            </label>
+
+            <input type="radio" id="theme-dark" name="theme" value="dark" />
+            <label for="theme-dark" title="Dark">
+              <img src="/web/src/dark_mode.svg" alt="" />
+              <span>Dark</span>
+            </label>
+          </div>
+        </div>
+      </div>
+    </div>
+
+    <p id="status"></p>
+  </div>
+
+  <div class="transcript-container">
+    <div id="linesTranscript"></div>
+  </div>
+
+  <script src="live_transcription.js"></script>
<script>
|
|
||||||
let isRecording = false;
|
|
||||||
let websocket = null;
|
|
||||||
let recorder = null;
|
|
||||||
let chunkDuration = 1000;
|
|
||||||
let websocketUrl = "ws://localhost:8000/asr";
|
|
||||||
let userClosing = false;
|
|
||||||
let startTime = null;
|
|
||||||
let timerInterval = null;
|
|
||||||
let audioContext = null;
|
|
||||||
let analyser = null;
|
|
||||||
let microphone = null;
|
|
||||||
let waveCanvas = document.getElementById("waveCanvas");
|
|
||||||
let waveCtx = waveCanvas.getContext("2d");
|
|
||||||
let animationFrame = null;
|
|
||||||
let waitingForStop = false;
|
|
||||||
let lastReceivedData = null;
|
|
||||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
|
||||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
|
||||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
|
||||||
|
|
||||||
const statusText = document.getElementById("status");
|
|
||||||
const recordButton = document.getElementById("recordButton");
|
|
||||||
const chunkSelector = document.getElementById("chunkSelector");
|
|
||||||
const websocketInput = document.getElementById("websocketInput");
|
|
||||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
|
||||||
const timerElement = document.querySelector(".timer");
|
|
||||||
|
|
||||||
const host = window.location.hostname || "localhost";
|
|
||||||
const port = window.location.port || "8000";
|
|
||||||
const protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
|
||||||
const defaultWebSocketUrl = `${protocol}://${host}:${port}/asr`;
|
|
||||||
websocketInput.value = defaultWebSocketUrl;
|
|
||||||
websocketUrl = defaultWebSocketUrl;
|
|
||||||
|
|
||||||
chunkSelector.addEventListener("change", () => {
|
|
||||||
chunkDuration = parseInt(chunkSelector.value);
|
|
||||||
});
|
|
||||||
|
|
||||||
websocketInput.addEventListener("change", () => {
|
|
||||||
const urlValue = websocketInput.value.trim();
|
|
||||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
|
||||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
websocketUrl = urlValue;
|
|
||||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
|
||||||
});
|
|
||||||
|
|
||||||
function setupWebSocket() {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
try {
|
|
||||||
websocket = new WebSocket(websocketUrl);
|
|
||||||
} catch (error) {
|
|
||||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
|
||||||
reject(error);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
websocket.onopen = () => {
|
|
||||||
statusText.textContent = "Connected to server.";
|
|
||||||
resolve();
|
|
||||||
};
|
|
||||||
|
|
||||||
websocket.onclose = () => {
|
|
||||||
if (userClosing) {
|
|
||||||
if (waitingForStop) {
|
|
||||||
statusText.textContent = "Processing finalized or connection closed.";
|
|
||||||
if (lastReceivedData) {
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lastReceivedData.lines || [],
|
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
0, 0, true // isFinalizing = true
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If ready_to_stop was received, statusText is already "Finished processing..."
|
|
||||||
// and waitingForStop is false.
|
|
||||||
} else {
|
|
||||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
|
||||||
if (isRecording) {
|
|
||||||
stopRecording();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
isRecording = false;
|
|
||||||
waitingForStop = false;
|
|
||||||
userClosing = false;
|
|
||||||
lastReceivedData = null;
|
|
||||||
websocket = null;
|
|
||||||
updateUI();
|
|
||||||
};
|
|
||||||
|
|
||||||
websocket.onerror = () => {
|
|
||||||
statusText.textContent = "Error connecting to WebSocket.";
|
|
||||||
reject(new Error("Error connecting to WebSocket"));
|
|
||||||
};
|
|
||||||
|
|
||||||
// Handle messages from server
|
|
||||||
websocket.onmessage = (event) => {
|
|
||||||
const data = JSON.parse(event.data);
|
|
||||||
|
|
||||||
// Check for status messages
|
|
||||||
if (data.type === "ready_to_stop") {
|
|
||||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
|
||||||
waitingForStop = false;
|
|
||||||
|
|
||||||
if (lastReceivedData) {
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lastReceivedData.lines || [],
|
|
||||||
lastReceivedData.buffer_diarization || "",
|
|
||||||
lastReceivedData.buffer_transcription || "",
|
|
||||||
0, // No more lag
|
|
||||||
0, // No more lag
|
|
||||||
true // isFinalizing = true
|
|
||||||
);
|
|
||||||
}
|
|
||||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
|
||||||
recordButton.disabled = false;
|
|
||||||
|
|
||||||
if (websocket) {
|
|
||||||
websocket.close(); // will trigger onclose
|
|
||||||
// websocket = null; // onclose handle setting websocket to null
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
lastReceivedData = data;
|
|
||||||
|
|
||||||
// Handle normal transcription updates
|
|
||||||
const {
|
|
||||||
lines = [],
|
|
||||||
buffer_transcription = "",
|
|
||||||
buffer_diarization = "",
|
|
||||||
remaining_time_transcription = 0,
|
|
||||||
remaining_time_diarization = 0,
|
|
||||||
status = "active_transcription"
|
|
||||||
} = data;
|
|
||||||
|
|
||||||
renderLinesWithBuffer(
|
|
||||||
lines,
|
|
||||||
buffer_diarization,
|
|
||||||
buffer_transcription,
|
|
||||||
remaining_time_diarization,
|
|
||||||
remaining_time_transcription,
|
|
||||||
false,
|
|
||||||
status
|
|
||||||
);
|
|
||||||
};
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function renderLinesWithBuffer(lines, buffer_diarization, buffer_transcription, remaining_time_diarization, remaining_time_transcription, isFinalizing = false, current_status = "active_transcription") {
|
|
||||||
if (current_status === "no_audio_detected") {
|
|
||||||
linesTranscriptDiv.innerHTML = "<p style='text-align: center; color: #666; margin-top: 20px;'><em>No audio detected...</em></p>";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const linesHtml = lines.map((item, idx) => {
|
|
||||||
let timeInfo = "";
|
|
||||||
if (item.beg !== undefined && item.end !== undefined) {
|
|
||||||
timeInfo = ` ${item.beg} - ${item.end}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
let speakerLabel = "";
|
|
||||||
if (item.speaker === -2) {
|
|
||||||
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
} else if (item.speaker == 0 && !isFinalizing) {
|
|
||||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${remaining_time_diarization} second(s) of audio are undergoing diarization</span></span>`;
|
|
||||||
} else if (item.speaker == -1) {
|
|
||||||
speakerLabel = `<span id="speaker">Speaker 1<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
} else if (item.speaker !== -1 && item.speaker !== 0) {
|
|
||||||
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
let currentLineText = item.text || "";
|
|
||||||
|
|
||||||
if (idx === lines.length - 1) {
|
|
||||||
if (!isFinalizing) {
|
|
||||||
if (remaining_time_transcription > 0) {
|
|
||||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'>${remaining_time_transcription}s</span></span>`;
|
|
||||||
}
|
|
||||||
if (buffer_diarization && remaining_time_diarization > 0) {
|
|
||||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'>${remaining_time_diarization}s</span></span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buffer_diarization) {
|
|
||||||
if (isFinalizing) {
|
|
||||||
currentLineText += (currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
|
||||||
} else {
|
|
||||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (buffer_transcription) {
|
|
||||||
if (isFinalizing) {
|
|
||||||
currentLineText += (currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") + buffer_transcription.trim();
|
|
||||||
} else {
|
|
||||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
|
||||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
|
||||||
: `<p>${speakerLabel}<br/></p>`;
|
|
||||||
}).join("");
|
|
||||||
|
|
||||||
linesTranscriptDiv.innerHTML = linesHtml;
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateTimer() {
|
|
||||||
if (!startTime) return;
|
|
||||||
|
|
||||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
|
||||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
|
||||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
|
||||||
timerElement.textContent = `${minutes}:${seconds}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
function drawWaveform() {
|
|
||||||
if (!analyser) return;
|
|
||||||
|
|
||||||
const bufferLength = analyser.frequencyBinCount;
|
|
||||||
const dataArray = new Uint8Array(bufferLength);
|
|
||||||
analyser.getByteTimeDomainData(dataArray);
|
|
||||||
|
|
||||||
waveCtx.clearRect(0, 0, waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1));
|
|
||||||
waveCtx.lineWidth = 1;
|
|
||||||
waveCtx.strokeStyle = 'rgb(0, 0, 0)';
|
|
||||||
waveCtx.beginPath();
|
|
||||||
|
|
||||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
|
||||||
let x = 0;
|
|
||||||
|
|
||||||
for (let i = 0; i < bufferLength; i++) {
|
|
||||||
const v = dataArray[i] / 128.0;
|
|
||||||
const y = v * (waveCanvas.height / (window.devicePixelRatio || 1)) / 2;
|
|
||||||
|
|
||||||
if (i === 0) {
|
|
||||||
waveCtx.moveTo(x, y);
|
|
||||||
} else {
|
|
||||||
waveCtx.lineTo(x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
x += sliceWidth;
|
|
||||||
}
|
|
||||||
|
|
||||||
waveCtx.lineTo(waveCanvas.width / (window.devicePixelRatio || 1), waveCanvas.height / (window.devicePixelRatio || 1) / 2);
|
|
||||||
waveCtx.stroke();
|
|
||||||
|
|
||||||
animationFrame = requestAnimationFrame(drawWaveform);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function startRecording() {
|
|
||||||
try {
|
|
||||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
|
||||||
|
|
||||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
|
||||||
analyser = audioContext.createAnalyser();
|
|
||||||
analyser.fftSize = 256;
|
|
||||||
microphone = audioContext.createMediaStreamSource(stream);
|
|
||||||
microphone.connect(analyser);
|
|
||||||
|
|
||||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
|
||||||
recorder.ondataavailable = (e) => {
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
websocket.send(e.data);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
recorder.start(chunkDuration);
|
|
||||||
|
|
||||||
startTime = Date.now();
|
|
||||||
timerInterval = setInterval(updateTimer, 1000);
|
|
||||||
drawWaveform();
|
|
||||||
|
|
||||||
isRecording = true;
|
|
||||||
updateUI();
|
|
||||||
} catch (err) {
|
|
||||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
|
||||||
console.error(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function stopRecording() {
|
|
||||||
userClosing = true;
|
|
||||||
waitingForStop = true;
|
|
||||||
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
// Send empty audio buffer as stop signal
|
|
||||||
const emptyBlob = new Blob([], { type: 'audio/webm' });
|
|
||||||
websocket.send(emptyBlob);
|
|
||||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
|
||||||
}
|
|
||||||
|
|
||||||
if (recorder) {
|
|
||||||
recorder.stop();
|
|
||||||
recorder = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (microphone) {
|
|
||||||
microphone.disconnect();
|
|
||||||
microphone = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (analyser) {
|
|
||||||
analyser = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (audioContext && audioContext.state !== 'closed') {
|
|
||||||
try {
|
|
||||||
audioContext.close();
|
|
||||||
} catch (e) {
|
|
||||||
console.warn("Could not close audio context:", e);
|
|
||||||
}
|
|
||||||
audioContext = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (animationFrame) {
|
|
||||||
cancelAnimationFrame(animationFrame);
|
|
||||||
animationFrame = null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (timerInterval) {
|
|
||||||
clearInterval(timerInterval);
|
|
||||||
timerInterval = null;
|
|
||||||
}
|
|
||||||
timerElement.textContent = "00:00";
|
|
||||||
startTime = null;
|
|
||||||
|
|
||||||
|
|
||||||
isRecording = false;
|
|
||||||
updateUI();
|
|
||||||
}
|
|
||||||
|
|
||||||
async function toggleRecording() {
|
|
||||||
if (!isRecording) {
|
|
||||||
if (waitingForStop) {
|
|
||||||
console.log("Waiting for stop, early return");
|
|
||||||
return; // Early return, UI is already updated
|
|
||||||
}
|
|
||||||
console.log("Connecting to WebSocket");
|
|
||||||
try {
|
|
||||||
// If we have an active WebSocket that's still processing, just restart audio capture
|
|
||||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
|
||||||
await startRecording();
|
|
||||||
} else {
|
|
||||||
// If no active WebSocket or it's closed, create new one
|
|
||||||
await setupWebSocket();
|
|
||||||
await startRecording();
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
|
||||||
console.error(err);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
console.log("Stopping recording");
|
|
||||||
stopRecording();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function updateUI() {
|
|
||||||
recordButton.classList.toggle("recording", isRecording);
|
|
||||||
recordButton.disabled = waitingForStop;
|
|
||||||
|
|
||||||
if (waitingForStop) {
|
|
||||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
|
||||||
statusText.textContent = "Please wait for processing to complete...";
|
|
||||||
}
|
|
||||||
} else if (isRecording) {
|
|
||||||
statusText.textContent = "Recording...";
|
|
||||||
} else {
|
|
||||||
if (statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
|
||||||
statusText.textContent !== "Processing finalized or connection closed.") {
|
|
||||||
statusText.textContent = "Click to start transcription";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!waitingForStop) {
|
|
||||||
recordButton.disabled = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
recordButton.addEventListener("click", toggleRecording);
|
|
||||||
</script>
|
|
||||||
 </body>

 </html>
whisperlivekit/web/live_transcription.js (new file, 803 lines)
@@ -0,0 +1,803 @@
|
|||||||
|
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
|
||||||
|
if (isExtension) {
|
||||||
|
document.documentElement.classList.add('is-extension');
|
||||||
|
}
|
||||||
|
const isWebContext = !isExtension;
|
||||||
|
|
||||||
|
let isRecording = false;
|
||||||
|
let websocket = null;
|
||||||
|
let recorder = null;
|
||||||
|
let chunkDuration = 100;
|
||||||
|
let websocketUrl = "ws://localhost:8000/asr";
|
||||||
|
let userClosing = false;
|
||||||
|
let wakeLock = null;
|
||||||
|
let startTime = null;
|
||||||
|
let timerInterval = null;
|
||||||
|
let audioContext = null;
|
||||||
|
let analyser = null;
|
||||||
|
let microphone = null;
|
||||||
|
let workletNode = null;
|
||||||
|
let recorderWorker = null;
|
||||||
|
let waveCanvas = document.getElementById("waveCanvas");
|
||||||
|
let waveCtx = waveCanvas.getContext("2d");
|
||||||
|
let animationFrame = null;
|
||||||
|
let waitingForStop = false;
|
||||||
|
let lastReceivedData = null;
|
||||||
|
let lastSignature = null;
|
||||||
|
let availableMicrophones = [];
|
||||||
|
let selectedMicrophoneId = null;
|
||||||
|
let serverUseAudioWorklet = null;
|
||||||
|
let configReadyResolve;
|
||||||
|
const configReady = new Promise((r) => (configReadyResolve = r));
|
||||||
|
let outputAudioContext = null;
|
||||||
|
let audioSource = null;
|
||||||
|
|
||||||
|
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||||
|
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||||
|
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||||
|
|
||||||
|
const statusText = document.getElementById("status");
|
||||||
|
const recordButton = document.getElementById("recordButton");
|
||||||
|
const chunkSelector = document.getElementById("chunkSelector");
|
||||||
|
const websocketInput = document.getElementById("websocketInput");
|
||||||
|
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||||
|
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||||
|
const timerElement = document.querySelector(".timer");
|
||||||
|
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||||
|
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||||
|
|
||||||
|
const settingsToggle = document.getElementById("settingsToggle");
|
||||||
|
const settingsDiv = document.querySelector(".settings");
|
||||||
|
|
||||||
|
// if (isExtension) {
|
||||||
|
// chrome.runtime.onInstalled.addListener((details) => {
|
||||||
|
// if (details.reason.search(/install/g) === -1) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
// chrome.tabs.create({
|
||||||
|
// url: chrome.runtime.getURL("welcome.html"),
|
||||||
|
// active: true
|
||||||
|
// });
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
|
||||||
|
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
|
||||||
|
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
|
||||||
|
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
|
||||||
|
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
|
||||||
|
|
||||||
|
function getWaveStroke() {
|
||||||
|
const styles = getComputedStyle(document.documentElement);
|
||||||
|
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||||
|
return v || "#000";
|
||||||
|
}
|
||||||
|
|
||||||
|
let waveStroke = getWaveStroke();
|
||||||
|
function updateWaveStroke() {
|
||||||
|
waveStroke = getWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyTheme(pref) {
|
||||||
|
if (pref === "light") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "light");
|
||||||
|
} else if (pref === "dark") {
|
||||||
|
document.documentElement.setAttribute("data-theme", "dark");
|
||||||
|
} else {
|
||||||
|
document.documentElement.removeAttribute("data-theme");
|
||||||
|
}
|
||||||
|
updateWaveStroke();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persisted theme preference
|
||||||
|
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||||
|
applyTheme(savedThemePref);
|
||||||
|
if (themeRadios.length) {
|
||||||
|
themeRadios.forEach((r) => {
|
||||||
|
r.checked = r.value === savedThemePref;
|
||||||
|
r.addEventListener("change", () => {
|
||||||
|
if (r.checked) {
|
||||||
|
localStorage.setItem("themePreference", r.value);
|
||||||
|
applyTheme(r.value);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// React to OS theme changes when in "system" mode
|
||||||
|
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||||
|
const handleOsThemeChange = () => {
|
||||||
|
const pref = localStorage.getItem("themePreference") || "system";
|
||||||
|
if (pref === "system") updateWaveStroke();
|
||||||
|
};
|
||||||
|
if (darkMq && darkMq.addEventListener) {
|
||||||
|
darkMq.addEventListener("change", handleOsThemeChange);
|
||||||
|
} else if (darkMq && darkMq.addListener) {
|
||||||
|
// deprecated, but included for Safari compatibility
|
||||||
|
darkMq.addListener(handleOsThemeChange);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function enumerateMicrophones() {
|
||||||
|
try {
|
||||||
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||||
|
stream.getTracks().forEach(track => track.stop());
|
||||||
|
|
||||||
|
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||||
|
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||||
|
|
||||||
|
populateMicrophoneSelect();
|
||||||
|
console.log(`Found ${availableMicrophones.length} microphone(s)`);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error enumerating microphones:', error);
|
||||||
|
statusText.textContent = "Error accessing microphones. Please grant permission.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function populateMicrophoneSelect() {
|
||||||
|
if (!microphoneSelect) return;
|
||||||
|
|
||||||
|
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||||
|
|
||||||
|
availableMicrophones.forEach((device, index) => {
|
||||||
|
const option = document.createElement('option');
|
||||||
|
option.value = device.deviceId;
|
||||||
|
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||||
|
microphoneSelect.appendChild(option);
|
||||||
|
});
|
||||||
|
|
||||||
|
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||||
|
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||||
|
microphoneSelect.value = savedMicId;
|
||||||
|
selectedMicrophoneId = savedMicId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleMicrophoneChange() {
|
||||||
|
selectedMicrophoneId = microphoneSelect.value || null;
|
||||||
|
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||||
|
|
||||||
|
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
|
||||||
|
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
|
||||||
|
|
||||||
|
console.log(`Selected microphone: ${deviceName}`);
|
||||||
|
statusText.textContent = `Microphone changed to: ${deviceName}`;
|
||||||
|
|
||||||
|
if (isRecording) {
|
||||||
|
statusText.textContent = "Switching microphone... Please wait.";
|
||||||
|
stopRecording().then(() => {
|
||||||
|
setTimeout(() => {
|
||||||
|
toggleRecording();
|
||||||
|
}, 1000);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
function fmt1(x) {
|
||||||
|
const n = Number(x);
|
||||||
|
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||||
|
}
|
||||||
|
|
||||||
|
let host, port, protocol;
|
||||||
|
port = 8000;
|
||||||
|
if (isExtension) {
|
||||||
|
host = "localhost";
|
||||||
|
protocol = "ws";
|
||||||
|
} else {
|
||||||
|
host = window.location.hostname || "localhost";
|
||||||
|
port = window.location.port;
|
||||||
|
protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||||
|
}
|
||||||
|
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||||
|
|
||||||
|
// Populate default caption and input
|
||||||
|
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||||
|
websocketInput.value = defaultWebSocketUrl;
|
||||||
|
websocketUrl = defaultWebSocketUrl;
|
||||||
|
|
||||||
|
// Optional chunk selector (guard for presence)
|
||||||
|
if (chunkSelector) {
|
||||||
|
chunkSelector.addEventListener("change", () => {
|
||||||
|
chunkDuration = parseInt(chunkSelector.value);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket input change handling
|
||||||
|
websocketInput.addEventListener("change", () => {
|
||||||
|
const urlValue = websocketInput.value.trim();
|
||||||
|
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||||
|
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
websocketUrl = urlValue;
|
||||||
|
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||||
|
});
|
||||||
|
|
||||||
|
function setupWebSocket() {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
try {
|
||||||
|
websocket = new WebSocket(websocketUrl);
|
||||||
|
} catch (error) {
|
||||||
|
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||||
|
reject(error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
websocket.onopen = () => {
|
||||||
|
statusText.textContent = "Connected to server.";
|
||||||
|
resolve();
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onclose = () => {
|
||||||
|
if (userClosing) {
|
||||||
|
if (waitingForStop) {
|
||||||
|
statusText.textContent = "Processing finalized or connection closed.";
|
||||||
|
if (lastReceivedData) {
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lastReceivedData.lines || [],
|
||||||
|
lastReceivedData.buffer_diarization || "",
|
||||||
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||||
|
if (isRecording) {
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isRecording = false;
|
||||||
|
waitingForStop = false;
|
||||||
|
userClosing = false;
|
||||||
|
lastReceivedData = null;
|
||||||
|
websocket = null;
|
||||||
|
updateUI();
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onerror = () => {
|
||||||
|
statusText.textContent = "Error connecting to WebSocket.";
|
||||||
|
reject(new Error("Error connecting to WebSocket"));
|
||||||
|
};
|
||||||
|
|
||||||
|
websocket.onmessage = (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
if (data.type === "config") {
|
||||||
|
serverUseAudioWorklet = !!data.useAudioWorklet;
|
||||||
|
statusText.textContent = serverUseAudioWorklet
|
||||||
|
? "Connected. Using AudioWorklet (PCM)."
|
||||||
|
: "Connected. Using MediaRecorder (WebM).";
|
||||||
|
if (configReadyResolve) configReadyResolve();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.type === "ready_to_stop") {
|
||||||
|
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||||
|
waitingForStop = false;
|
||||||
|
|
||||||
|
if (lastReceivedData) {
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lastReceivedData.lines || [],
|
||||||
|
lastReceivedData.buffer_diarization || "",
|
||||||
|
lastReceivedData.buffer_transcription || "",
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
}
|
||||||
|
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||||
|
recordButton.disabled = false;
|
||||||
|
|
||||||
|
if (websocket) {
|
||||||
|
websocket.close();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
lastReceivedData = data;
|
||||||
|
|
||||||
|
const {
|
||||||
|
lines = [],
|
||||||
|
buffer_transcription = "",
|
||||||
|
buffer_diarization = "",
|
||||||
|
remaining_time_transcription = 0,
|
||||||
|
remaining_time_diarization = 0,
|
||||||
|
status = "active_transcription",
|
||||||
|
} = data;
|
||||||
|
|
||||||
|
renderLinesWithBuffer(
|
||||||
|
lines,
|
||||||
|
buffer_diarization,
|
||||||
|
buffer_transcription,
|
||||||
|
remaining_time_diarization,
|
||||||
|
remaining_time_transcription,
|
||||||
|
false,
|
||||||
|
status
|
||||||
|
);
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderLinesWithBuffer(
|
||||||
|
lines,
|
||||||
|
buffer_diarization,
|
||||||
|
buffer_transcription,
|
||||||
|
remaining_time_diarization,
|
||||||
|
remaining_time_transcription,
|
||||||
|
isFinalizing = false,
|
||||||
|
current_status = "active_transcription"
|
||||||
|
) {
|
||||||
|
if (current_status === "no_audio_detected") {
|
||||||
|
linesTranscriptDiv.innerHTML =
|
||||||
|
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||||
|
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||||
|
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
|
||||||
|
const signature = JSON.stringify({
|
||||||
|
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||||
|
buffer_transcription: buffer_transcription || "",
|
||||||
|
buffer_diarization: buffer_diarization || "",
|
||||||
|
status: current_status,
|
||||||
|
showLoading,
|
||||||
|
showTransLag,
|
||||||
|
showDiaLag,
|
||||||
|
isFinalizing: !!isFinalizing,
|
||||||
|
});
|
||||||
|
if (lastSignature === signature) {
|
||||||
|
const t = document.querySelector(".lag-transcription-value");
|
||||||
|
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||||
|
const d = document.querySelector(".lag-diarization-value");
|
||||||
|
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||||
|
const ld = document.querySelector(".loading-diarization-value");
|
||||||
|
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lastSignature = signature;
|
||||||
|
|
||||||
|
const linesHtml = (lines || [])
|
||||||
|
.map((item, idx) => {
|
||||||
|
let timeInfo = "";
|
||||||
|
if (item.start !== undefined && item.end !== undefined) {
|
||||||
|
timeInfo = ` ${item.start} - ${item.end}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
let speakerLabel = "";
|
||||||
|
if (item.speaker === -2) {
|
||||||
|
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
} else if (item.speaker == 0 && !isFinalizing) {
|
||||||
|
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||||
|
remaining_time_diarization
|
||||||
|
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||||
|
} else if (item.speaker !== 0) {
|
||||||
|
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||||
|
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||||
|
|
||||||
|
if (item.detected_language) {
|
||||||
|
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let currentLineText = item.text || "";
|
||||||
|
|
||||||
|
if (idx === lines.length - 1) {
|
||||||
|
if (!isFinalizing && item.speaker !== -2) {
|
||||||
|
if (remaining_time_transcription > 0) {
|
||||||
|
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||||
|
remaining_time_transcription
|
||||||
|
)}</span>s</span></span>`;
|
||||||
|
}
|
||||||
|
if (buffer_diarization && remaining_time_diarization > 0) {
|
||||||
|
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||||
|
remaining_time_diarization
|
||||||
|
)}</span>s</span></span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (buffer_diarization) {
|
||||||
|
if (isFinalizing) {
|
||||||
|
currentLineText +=
|
||||||
|
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||||
|
} else {
|
||||||
|
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (buffer_transcription) {
|
||||||
|
if (isFinalizing) {
|
||||||
|
currentLineText +=
|
||||||
|
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||||
|
buffer_transcription.trim();
|
||||||
|
} else {
|
||||||
|
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (item.translation) {
|
||||||
|
currentLineText += `
|
||||||
|
<div>
|
||||||
|
<div class="label_translation">
|
||||||
|
${translationIcon}
|
||||||
|
<span>${item.translation}</span>
|
||||||
|
</div>
|
||||||
|
</div>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||||
|
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||||
|
: `<p>${speakerLabel}<br/></p>`;
|
||||||
|
})
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
linesTranscriptDiv.innerHTML = linesHtml;
|
||||||
|
const transcriptContainer = document.querySelector('.transcript-container');
|
||||||
|
if (transcriptContainer) {
|
||||||
|
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTimer() {
|
||||||
|
if (!startTime) return;
|
||||||
|
|
||||||
|
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||||
|
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||||
|
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||||
|
timerElement.textContent = `${minutes}:${seconds}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function drawWaveform() {
|
||||||
|
if (!analyser) return;
|
||||||
|
|
||||||
|
const bufferLength = analyser.frequencyBinCount;
|
||||||
|
const dataArray = new Uint8Array(bufferLength);
|
||||||
|
analyser.getByteTimeDomainData(dataArray);
|
||||||
|
|
||||||
|
waveCtx.clearRect(
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||||
|
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||||
|
);
|
||||||
|
waveCtx.lineWidth = 1;
|
||||||
|
waveCtx.strokeStyle = waveStroke;
|
||||||
|
waveCtx.beginPath();
|
||||||
|
|
||||||
|
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||||
|
let x = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < bufferLength; i++) {
|
||||||
|
const v = dataArray[i] / 128.0;
|
||||||
|
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||||
|
|
||||||
|
if (i === 0) {
|
||||||
|
waveCtx.moveTo(x, y);
|
||||||
|
} else {
|
||||||
|
waveCtx.lineTo(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
x += sliceWidth;
|
||||||
|
}
|
||||||
|
|
||||||
|
waveCtx.lineTo(
|
||||||
|
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||||
|
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||||
|
);
|
||||||
|
waveCtx.stroke();
|
||||||
|
|
||||||
|
animationFrame = requestAnimationFrame(drawWaveform);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function startRecording() {
|
||||||
|
try {
|
||||||
|
try {
|
||||||
|
wakeLock = await navigator.wakeLock.request("screen");
|
||||||
|
} catch (err) {
|
||||||
|
console.log("Error acquiring wake lock.");
|
||||||
|
}
|
||||||
|
|
||||||
|
let stream;
|
||||||
|
|
||||||
|
// chromium extension. in the future, both chrome page audio and mic will be used
|
||||||
|
if (isExtension) {
|
||||||
|
try {
|
||||||
|
stream = await new Promise((resolve, reject) => {
|
||||||
|
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||||
|
if (s) {
|
||||||
|
resolve(s);
|
||||||
|
} else {
|
||||||
|
reject(new Error('Tab capture failed or not available'));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
|
audioSource = outputAudioContext.createMediaStreamSource(stream);
|
||||||
|
audioSource.connect(outputAudioContext.destination);
|
||||||
|
} catch (audioError) {
|
||||||
|
console.warn('could not preserve system audio:', audioError);
|
||||||
|
}
|
||||||
|
|
||||||
|
statusText.textContent = "Using tab audio capture.";
|
||||||
|
} catch (tabError) {
|
||||||
|
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||||
|
const audioConstraints = selectedMicrophoneId
|
||||||
|
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||||
|
: { audio: true };
|
||||||
|
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||||
|
statusText.textContent = "Using microphone audio.";
|
||||||
|
}
|
||||||
|
} else if (isWebContext) {
|
||||||
|
const audioConstraints = selectedMicrophoneId
|
||||||
|
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||||
|
: { audio: true };
|
||||||
|
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||||
|
}
|
||||||
|
|
||||||
|
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||||
|
analyser = audioContext.createAnalyser();
|
||||||
|
analyser.fftSize = 256;
|
||||||
|
microphone = audioContext.createMediaStreamSource(stream);
|
||||||
|
microphone.connect(analyser);
|
||||||
|
|
||||||
|
if (serverUseAudioWorklet) {
|
||||||
|
if (!audioContext.audioWorklet) {
|
||||||
|
throw new Error("AudioWorklet is not supported in this browser");
|
||||||
|
}
|
||||||
|
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
|
||||||
|
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
|
||||||
|
microphone.connect(workletNode);
|
||||||
|
|
||||||
|
recorderWorker = new Worker("/web/recorder_worker.js");
|
||||||
|
recorderWorker.postMessage({
|
||||||
|
command: "init",
|
||||||
|
config: {
|
||||||
|
sampleRate: audioContext.sampleRate,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
recorderWorker.onmessage = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
websocket.send(e.data.buffer);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
workletNode.port.onmessage = (e) => {
|
||||||
|
const data = e.data;
|
||||||
|
const ab = data instanceof ArrayBuffer ? data : data.buffer;
|
||||||
|
recorderWorker.postMessage(
|
||||||
|
{
|
||||||
|
command: "record",
|
||||||
|
buffer: ab,
|
||||||
|
},
|
||||||
|
[ab]
|
||||||
|
);
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||||
|
} catch (e) {
|
||||||
|
recorder = new MediaRecorder(stream);
|
||||||
|
}
|
||||||
|
recorder.ondataavailable = (e) => {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
if (e.data && e.data.size > 0) {
|
||||||
|
websocket.send(e.data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
recorder.start(chunkDuration);
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime = Date.now();
|
||||||
|
timerInterval = setInterval(updateTimer, 1000);
|
||||||
|
drawWaveform();
|
||||||
|
|
||||||
|
isRecording = true;
|
||||||
|
updateUI();
|
||||||
|
} catch (err) {
|
||||||
|
if (window.location.hostname === "0.0.0.0") {
|
||||||
|
statusText.textContent =
|
||||||
|
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||||
|
} else {
|
||||||
|
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||||
|
}
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function stopRecording() {
|
||||||
|
if (wakeLock) {
|
||||||
|
try {
|
||||||
|
await wakeLock.release();
|
||||||
|
} catch (e) {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
wakeLock = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
userClosing = true;
|
||||||
|
waitingForStop = true;
|
||||||
|
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||||
|
websocket.send(emptyBlob);
|
||||||
|
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recorder) {
|
||||||
|
try {
|
||||||
|
recorder.stop();
|
||||||
|
} catch (e) {
|
||||||
|
}
|
||||||
|
recorder = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (recorderWorker) {
|
||||||
|
recorderWorker.terminate();
|
||||||
|
recorderWorker = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (workletNode) {
|
||||||
|
try {
|
||||||
|
workletNode.port.onmessage = null;
|
||||||
|
} catch (e) {}
|
||||||
|
try {
|
||||||
|
workletNode.disconnect();
|
||||||
|
} catch (e) {}
|
||||||
|
workletNode = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (microphone) {
|
||||||
|
microphone.disconnect();
|
||||||
|
microphone = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (analyser) {
|
||||||
|
analyser = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audioContext && audioContext.state !== "closed") {
|
||||||
|
try {
|
||||||
|
await audioContext.close();
|
||||||
|
} catch (e) {
|
||||||
|
console.warn("Could not close audio context:", e);
|
||||||
|
}
|
||||||
|
audioContext = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (audioSource) {
|
||||||
|
audioSource.disconnect();
|
||||||
|
audioSource = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outputAudioContext && outputAudioContext.state !== "closed") {
|
||||||
|
outputAudioContext.close()
|
||||||
|
outputAudioContext = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (animationFrame) {
|
||||||
|
cancelAnimationFrame(animationFrame);
|
||||||
|
animationFrame = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timerInterval) {
|
||||||
|
clearInterval(timerInterval);
|
||||||
|
timerInterval = null;
|
||||||
|
}
|
||||||
|
timerElement.textContent = "00:00";
|
||||||
|
startTime = null;
|
||||||
|
|
||||||
|
isRecording = false;
|
||||||
|
updateUI();
|
||||||
|
}
|
||||||
|
|
||||||
|
async function toggleRecording() {
|
||||||
|
if (!isRecording) {
|
||||||
|
if (waitingForStop) {
|
||||||
|
console.log("Waiting for stop, early return");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.log("Connecting to WebSocket");
|
||||||
|
try {
|
||||||
|
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||||
|
await configReady;
|
||||||
|
await startRecording();
|
||||||
|
} else {
|
||||||
|
await setupWebSocket();
|
||||||
|
await configReady;
|
||||||
|
await startRecording();
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||||
|
console.error(err);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.log("Stopping recording");
|
||||||
|
stopRecording();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateUI() {
|
||||||
|
recordButton.classList.toggle("recording", isRecording);
|
||||||
|
recordButton.disabled = waitingForStop;
|
||||||
|
|
||||||
|
if (waitingForStop) {
|
||||||
|
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||||
|
statusText.textContent = "Please wait for processing to complete...";
|
||||||
|
}
|
||||||
|
} else if (isRecording) {
|
||||||
|
statusText.textContent = "";
|
||||||
|
} else {
|
||||||
|
if (
|
||||||
|
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||||
|
statusText.textContent !== "Processing finalized or connection closed."
|
||||||
|
) {
|
||||||
|
statusText.textContent = "Click to start transcription";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!waitingForStop) {
|
||||||
|
recordButton.disabled = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recordButton.addEventListener("click", toggleRecording);
|
||||||
|
|
||||||
|
if (microphoneSelect) {
|
||||||
|
microphoneSelect.addEventListener("change", handleMicrophoneChange);
|
||||||
|
}
|
||||||
|
document.addEventListener('DOMContentLoaded', async () => {
|
||||||
|
try {
|
||||||
|
await enumerateMicrophones();
|
||||||
|
} catch (error) {
|
||||||
|
console.log("Could not enumerate microphones on load:", error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||||
|
console.log('Device change detected, re-enumerating microphones');
|
||||||
|
try {
|
||||||
|
await enumerateMicrophones();
|
||||||
|
} catch (error) {
|
||||||
|
console.log("Error re-enumerating microphones:", error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
settingsToggle.addEventListener("click", () => {
|
||||||
|
settingsDiv.classList.toggle("visible");
|
||||||
|
settingsToggle.classList.toggle("active");
|
||||||
|
});
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
async function checkAndRequestPermissions() {
|
||||||
|
const micPermission = await navigator.permissions.query({
|
||||||
|
name: "microphone",
|
||||||
|
});
|
||||||
|
|
||||||
|
const permissionDisplay = document.getElementById("audioPermission");
|
||||||
|
if (permissionDisplay) {
|
||||||
|
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (micPermission.state !== "granted") {
|
||||||
|
// chrome.tabs.create({ url: "welcome.html" });
|
||||||
|
// }
|
||||||
|
|
||||||
|
const intervalId = setInterval(async () => {
|
||||||
|
const micPermission = await navigator.permissions.query({
|
||||||
|
name: "microphone",
|
||||||
|
});
|
||||||
|
if (micPermission.state === "granted") {
|
||||||
|
if (permissionDisplay) {
|
||||||
|
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||||
|
}
|
||||||
|
clearInterval(intervalId);
|
||||||
|
}
|
||||||
|
}, 100);
|
||||||
|
}
|
||||||
|
|
||||||
|
void checkAndRequestPermissions();
|
||||||
|
}
|
||||||
16
whisperlivekit/web/pcm_worklet.js
Normal file
@@ -0,0 +1,16 @@
class PCMForwarder extends AudioWorkletProcessor {
  process(inputs) {
    const input = inputs[0];
    if (input && input[0] && input[0].length) {
      // Forward mono channel (0). If multi-channel, downmixing can be added here.
      const channelData = input[0];
      const copy = new Float32Array(channelData.length);
      copy.set(channelData);
      this.port.postMessage(copy, [copy.buffer]);
    }
    // Keep processor alive
    return true;
  }
}

registerProcessor('pcm-forwarder', PCMForwarder);
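The worklet above forwards only channel 0, and its inline comment leaves multi-channel downmixing as a possible extension. A minimal sketch of what that could look like (the downmixToMono helper is hypothetical, not part of the shipped worklet):

// Hypothetical helper, assuming `channels` is the worklet's per-channel Float32Array input list:
// average every channel into one mono buffer instead of forwarding channel 0 only.
function downmixToMono(channels) {
  const length = channels[0].length;
  const mono = new Float32Array(length);
  for (let c = 0; c < channels.length; c++) {
    const data = channels[c];
    for (let i = 0; i < length; i++) {
      mono[i] += data[i] / channels.length;
    }
  }
  return mono;
}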
58
whisperlivekit/web/recorder_worker.js
Normal file
@@ -0,0 +1,58 @@
let sampleRate = 48000;
let targetSampleRate = 16000;

self.onmessage = function (e) {
  switch (e.data.command) {
    case 'init':
      init(e.data.config);
      break;
    case 'record':
      record(e.data.buffer);
      break;
  }
};

function init(config) {
  sampleRate = config.sampleRate;
  targetSampleRate = config.targetSampleRate || 16000;
}

function record(inputBuffer) {
  const buffer = new Float32Array(inputBuffer);
  const resampledBuffer = resample(buffer, sampleRate, targetSampleRate);
  const pcmBuffer = toPCM(resampledBuffer);
  self.postMessage({ buffer: pcmBuffer }, [pcmBuffer]);
}

function resample(buffer, from, to) {
  if (from === to) {
    return buffer;
  }
  const ratio = from / to;
  const newLength = Math.round(buffer.length / ratio);
  const result = new Float32Array(newLength);
  let offsetResult = 0;
  let offsetBuffer = 0;
  while (offsetResult < result.length) {
    const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
    let accum = 0, count = 0;
    for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
      accum += buffer[i];
      count++;
    }
    result[offsetResult] = accum / count;
    offsetResult++;
    offsetBuffer = nextOffsetBuffer;
  }
  return result;
}

function toPCM(input) {
  const buffer = new ArrayBuffer(input.length * 2);
  const view = new DataView(buffer);
  for (let i = 0; i < input.length; i++) {
    const s = Math.max(-1, Math.min(1, input[i]));
    view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
  }
  return buffer;
}
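As a quick sanity check of the worker above (a standalone sketch, not repository code): the averaging resampler shrinks each block by the sample-rate ratio, and each 16 kHz sample then becomes two bytes of little-endian Int16 PCM.

// Standalone sketch mirroring the worker's length arithmetic; names here are illustrative only.
function resampledLength(inputLength, from, to) {
  return Math.round(inputLength / (from / to)); // same rounding as resample()
}
const block = 1024;                                    // e.g. a batch of worklet frames captured at 48 kHz
console.log(resampledLength(block, 48000, 16000));     // 341 samples at 16 kHz
console.log(resampledLength(block, 48000, 16000) * 2); // 682 bytes of Int16 PCM sent over the WebSocket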
New SVG icon assets under whisperlivekit/web/src/ (one-line 24px Material-style icons, fill #5f6368; path data omitted):
1  whisperlivekit/web/src/dark_mode.svg    Normal file  (493 B)
1  whisperlivekit/web/src/language.svg     Normal file  (976 B)
1  whisperlivekit/web/src/light_mode.svg   Normal file  (1.2 KiB)
1  whisperlivekit/web/src/settings.svg     Normal file  (982 B)
1  whisperlivekit/web/src/silence.svg      Normal file  (984 B)
1  whisperlivekit/web/src/speaker.svg      Normal file  (592 B)
1  whisperlivekit/web/src/system_mode.svg  Normal file  (1.4 KiB)
1  whisperlivekit/web/src/translate.svg    Normal file  (650 B)
@@ -1,5 +1,6 @@
 import logging
 import importlib.resources as resources
+import base64

 logger = logging.getLogger(__name__)

@@ -10,4 +11,104 @@ def get_web_interface_html():
             return f.read()
     except Exception as e:
         logger.error(f"Error loading web interface HTML: {e}")
         return "<html><body><h1>Error loading interface</h1></body></html>"
+
+
+def get_inline_ui_html():
+    """Returns the complete web interface HTML with all assets embedded in a single call."""
+    try:
+        with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
+            html_content = f.read()
+        with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
+            css_content = f.read()
+        with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
+            js_content = f.read()
+
+        with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
+            worklet_code = f.read()
+        with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
+            worker_code = f.read()
+
+        js_content = js_content.replace(
+            'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
+            'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
+            'const workletUrl = URL.createObjectURL(workletBlob);\n' +
+            'await audioContext.audioWorklet.addModule(workletUrl);'
+        )
+        js_content = js_content.replace(
+            'recorderWorker = new Worker("/web/recorder_worker.js");',
+            'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
+            'const workerUrl = URL.createObjectURL(workerBlob);\n' +
+            'recorderWorker = new Worker(workerUrl);'
+        )
+
+        # SVG files
+        with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
+            system_svg = f.read()
+        system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
+        with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
+            light_svg = f.read()
+        light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
+        with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
+            dark_svg = f.read()
+        dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
+        with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
+            settings = f.read()
+        settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
+
+        # Replace external references
+        html_content = html_content.replace(
+            '<link rel="stylesheet" href="live_transcription.css" />',
+            f'<style>\n{css_content}\n</style>'
+        )
+
+        html_content = html_content.replace(
+            '<script src="live_transcription.js"></script>',
+            f'<script>\n{js_content}\n</script>'
+        )
+
+        # Replace SVG references
+        html_content = html_content.replace(
+            '<img src="/web/src/system_mode.svg" alt="" />',
+            f'<img src="{system_data_uri}" alt="" />'
+        )
+
+        html_content = html_content.replace(
+            '<img src="/web/src/light_mode.svg" alt="" />',
+            f'<img src="{light_data_uri}" alt="" />'
+        )
+
+        html_content = html_content.replace(
+            '<img src="/web/src/dark_mode.svg" alt="" />',
+            f'<img src="{dark_data_uri}" alt="" />'
+        )
+
+        html_content = html_content.replace(
+            '<img src="web/src/settings.svg" alt="Settings" />',
+            f'<img src="{settings_uri}" alt="" />'
+        )
+
+        return html_content
+
+    except Exception as e:
+        logger.error(f"Error creating embedded web interface: {e}")
+        return "<html><body><h1>Error loading embedded interface</h1></body></html>"
+
+
+if __name__ == '__main__':
+
+    from fastapi import FastAPI
+    from fastapi.responses import HTMLResponse
+    import uvicorn
+    from starlette.staticfiles import StaticFiles
+    import pathlib
+    import whisperlivekit.web as webpkg
+
+    app = FastAPI()
+    web_dir = pathlib.Path(webpkg.__file__).parent
+    app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
+
+    @app.get("/")
+    async def get():
+        return HTMLResponse(get_inline_ui_html())
+
+    uvicorn.run(app=app)
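For context, the string replacements above swap the file-based worklet/worker URLs for blob URLs so the single-file page needs no extra requests. A minimal sketch of that pattern in the browser (the echo worker body is a placeholder, not WhisperLiveKit code):

// Minimal blob-URL worker sketch, assuming a browser environment.
const workerSource = 'self.onmessage = (e) => self.postMessage(e.data);'; // placeholder body
const blob = new Blob([workerSource], { type: "application/javascript" });
const url = URL.createObjectURL(blob);
const worker = new Worker(url);
worker.onmessage = (e) => console.log("echo:", e.data);
worker.postMessage("ping");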
@@ -3,41 +3,22 @@ import logging
 import io
 import soundfile as sf
 import math
-try:
-    import torch
-except ImportError:
-    torch = None
 from typing import List
 import numpy as np
 from whisperlivekit.timed_objects import ASRToken

 logger = logging.getLogger(__name__)

-try:
-    from whisperlivekit.simul_whisper.config import AlignAttConfig
-    from whisperlivekit.simul_whisper.simul_whisper import PaddedAlignAttWhisper, DEC_PAD
-    from whisperlivekit.simul_whisper.whisper import tokenizer
-    SIMULSTREAMING_AVAILABLE = True
-except ImportError:
-    logger.warning("SimulStreaming dependencies not available. SimulStreaming backend will not be available.")
-    SIMULSTREAMING_AVAILABLE = False
-    AlignAttConfig = None
-    PaddedAlignAttWhisper = None
-    DEC_PAD = None
-    tokenizer = None
-
 class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
                # "" for faster-whisper because it emits the spaces when needed)

-    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
+    def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
         self.logfile = logfile
         self.transcribe_kargs = {}
         if lan == "auto":
             self.original_language = None
         else:
             self.original_language = lan
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
+        self.model = self.load_model(model_size, cache_dir, model_dir)

     def with_offset(self, offset: float) -> ASRToken:
         # This method is kept for compatibility (typically you will use ASRToken.with_offset)
@@ -46,7 +27,7 @@ class ASRBase:
     def __repr__(self):
         return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

-    def load_model(self, modelsize, cache_dir, model_dir):
+    def load_model(self, model_size, cache_dir, model_dir):
         raise NotImplementedError("must be implemented in the child class")

     def transcribe(self, audio, init_prompt=""):
@@ -60,7 +41,7 @@ class WhisperTimestampedASR(ASRBase):
     """Uses whisper_timestamped as the backend."""
     sep = " "

-    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
         import whisper
         import whisper_timestamped
         from whisper_timestamped import transcribe_timestamped
@@ -68,7 +49,7 @@ class WhisperTimestampedASR(ASRBase):
         self.transcribe_timestamped = transcribe_timestamped
         if model_dir is not None:
             logger.debug("ignoring model_dir, not implemented")
-        return whisper.load_model(modelsize, download_root=cache_dir)
+        return whisper.load_model(model_size, download_root=cache_dir)

     def transcribe(self, audio, init_prompt=""):
         result = self.transcribe_timestamped(
@@ -107,17 +88,17 @@ class FasterWhisperASR(ASRBase):
     """Uses faster-whisper as the backend."""
     sep = ""

-    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
         from faster_whisper import WhisperModel

         if model_dir is not None:
             logger.debug(f"Loading whisper model from model_dir {model_dir}. "
-                         f"modelsize and cache_dir parameters are not used.")
+                         f"model_size and cache_dir parameters are not used.")
             model_size_or_path = model_dir
-        elif modelsize is not None:
-            model_size_or_path = modelsize
+        elif model_size is not None:
+            model_size_or_path = model_size
         else:
-            raise ValueError("Either modelsize or model_dir must be set")
+            raise ValueError("Either model_size or model_dir must be set")
         device = "auto"  # Allow CTranslate2 to decide available device
         compute_type = "auto"  # Allow CTranslate2 to decide faster compute type

@@ -168,18 +149,18 @@ class MLXWhisper(ASRBase):
     """
     sep = ""

-    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+    def load_model(self, model_size=None, cache_dir=None, model_dir=None):
         from mlx_whisper.transcribe import ModelHolder, transcribe
         import mlx.core as mx

         if model_dir is not None:
-            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. model_size parameter is not used.")
             model_size_or_path = model_dir
-        elif modelsize is not None:
-            model_size_or_path = self.translate_model_name(modelsize)
-            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
+        elif model_size is not None:
+            model_size_or_path = self.translate_model_name(model_size)
+            logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
         else:
-            raise ValueError("Either modelsize or model_dir must be set")
+            raise ValueError("Either model_size or model_dir must be set")

         self.model_size_or_path = model_size_or_path
         dtype = mx.float16
@@ -306,181 +287,4 @@ class OpenaiApiASR(ASRBase):
         self.use_vad_opt = True

     def set_translate_task(self):
         self.task = "translate"
-
-
-class SimulStreamingASR(ASRBase):
-    """SimulStreaming backend with AlignAtt policy."""
-    sep = " "
-
-    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr, **kwargs):
-        if not SIMULSTREAMING_AVAILABLE:
-            raise ImportError("""SimulStreaming dependencies are not available. Please install WhisperLiveKit using pip install "whisperlivekit[simulstreaming]". If you are building from source, you should also copy the content of the simul_whisper directory from the SimulStreaming repository into whisperlivekit/simul_whisper.""")
-        with open("whisperlivekit/simul_whisper/dual_license_simulstreaming.md", "r") as f:
-            print("*"*80 + f.read() + "*"*80)
-        self.logfile = logfile
-        self.transcribe_kargs = {}
-        self.original_language = None if lan == "auto" else lan
-
-        self.model_path = kwargs.get('model_path', './large-v3.pt')
-        self.frame_threshold = kwargs.get('frame_threshold', 25)
-        self.audio_max_len = kwargs.get('audio_max_len', 30.0)
-        self.audio_min_len = kwargs.get('audio_min_len', 0.0)
-        self.segment_length = kwargs.get('segment_length', 0.5)
-        self.beams = kwargs.get('beams', 1)
-        self.decoder_type = kwargs.get('decoder_type', 'greedy' if self.beams == 1 else 'beam')
-        self.task = kwargs.get('task', 'transcribe')
-        self.cif_ckpt_path = kwargs.get('cif_ckpt_path', None)
-        self.never_fire = kwargs.get('never_fire', False)
-        self.init_prompt = kwargs.get('init_prompt', None)
-        self.static_init_prompt = kwargs.get('static_init_prompt', None)
-        self.max_context_tokens = kwargs.get('max_context_tokens', None)
-
-        if model_dir is not None:
-            self.model_path = model_dir
-        elif modelsize is not None: #For the moment the .en.pt models do not work!
-            model_mapping = {
-                'tiny': './tiny.pt',
-                'base': './base.pt',
-                'small': './small.pt',
-                'medium': './medium.pt',
-                'medium.en': './medium.en.pt',
-                'large-v1': './large-v1.pt',
-                'base.en': './base.en.pt',
-                'small.en': './small.en.pt',
-                'tiny.en': './tiny.en.pt',
-                'large-v2': './large-v2.pt',
-                'large-v3': './large-v3.pt',
-                'large': './large-v3.pt'
-            }
-            self.model_path = model_mapping.get(modelsize, f'./{modelsize}.pt')
-
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
-
-        # Set up tokenizer for translation if needed
-        if self.task == "translate":
-            self.set_translate_task()
-
-    def load_model(self, modelsize, cache_dir, model_dir):
-        try:
-            cfg = AlignAttConfig(
-                model_path=self.model_path,
-                segment_length=self.segment_length,
-                frame_threshold=self.frame_threshold,
-                language=self.original_language,
-                audio_max_len=self.audio_max_len,
-                audio_min_len=self.audio_min_len,
-                cif_ckpt_path=self.cif_ckpt_path,
-                decoder_type="beam",
-                beam_size=self.beams,
-                task=self.task,
-                never_fire=self.never_fire,
-                init_prompt=self.init_prompt,
-                max_context_tokens=self.max_context_tokens,
-                static_init_prompt=self.static_init_prompt,
-            )
-
-            logger.info(f"Loading SimulStreaming model with language: {self.original_language}")
-            model = PaddedAlignAttWhisper(cfg)
-            return model
-
-        except Exception as e:
-            logger.error(f"Failed to load SimulStreaming model: {e}")
-            raise
-
-    def transcribe(self, audio, init_prompt=""):
-        """Transcribe audio using SimulStreaming."""
-        try:
-            if isinstance(audio, np.ndarray):
-                audio_tensor = torch.from_numpy(audio).float()
-            else:
-                audio_tensor = audio
-
-            prompt = init_prompt if init_prompt else (self.init_prompt or "")
-
-            result = self.model.infer(audio_tensor, init_prompt=prompt)
-
-            if torch.is_tensor(result):
-                result = result[result < DEC_PAD]
-
-            logger.debug(f"SimulStreaming transcription result: {result}")
-            return result
-
-        except Exception as e:
-            logger.error(f"SimulStreaming transcription failed: {e}")
-            raise
-
-    def ts_words(self, result) -> List[ASRToken]:
-        """Convert SimulStreaming result to ASRToken list."""
-        tokens = []
-
-        try:
-            if torch.is_tensor(result):
-                text = self.model.tokenizer.decode(result.cpu().numpy())
-            else:
-                text = str(result)
-
-            if not text or len(text.strip()) == 0:
-                return tokens
-
-            # We dont have word-level timestamps here. 1rst approach, should be improved later.
-            words = text.strip().split()
-            if not words:
-                return tokens
-
-            duration_per_word = 0.1  # this will be modified based on actual audio duration
-            #with the SimulStreamingOnlineProcessor
-
-            for i, word in enumerate(words):
-                start_time = i * duration_per_word
-                end_time = (i + 1) * duration_per_word
-
-                token = ASRToken(
-                    start=start_time,
-                    end=end_time,
-                    text=word,
-                    probability=1.0
-                )
-                tokens.append(token)
-
-        except Exception as e:
-            logger.error(f"Error converting SimulStreaming result to tokens: {e}")
-
-        return tokens
-
-    def segments_end_ts(self, result) -> List[float]:
-        """Get segment end timestamps."""
-        if torch.is_tensor(result):
-            num_tokens = len(result)
-            return [num_tokens * 0.1]  # rough estimate
-        return [1.0]
-
-    def use_vad(self):
-        """Enable VAD - SimulStreaming has different VAD handling."""
-        logger.info("VAD requested for SimulStreaming - handled internally by the model")
-        pass
-
-    def set_translate_task(self):
-        """Set up translation task."""
-        try:
-            self.model.tokenizer = tokenizer.get_tokenizer(
-                multilingual=True,
-                language=self.model.cfg.language,
-                num_languages=self.model.model.num_languages,
-                task="translate"
-            )
-            logger.info("SimulStreaming configured for translation task")
-        except Exception as e:
-            logger.error(f"Failed to configure SimulStreaming for translation: {e}")
-            raise
-
-    def warmup(self, audio, init_prompt=""):
-        """Warmup the SimulStreaming model."""
-        try:
-            if isinstance(audio, np.ndarray):
-                audio = torch.from_numpy(audio).float()
-            self.model.infer(audio, True)
-            self.model.refresh_segment(complete=True)
-            logger.info("SimulStreaming model warmed up successfully")
-        except Exception as e:
-            logger.warning(f"SimulStreaming warmup failed: {e}")
@@ -6,18 +6,6 @@ from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript

 logger = logging.getLogger(__name__)

-# simulStreaming imports - we check if the files are here
-try:
-    import torch
-    from simul_whisper.config import AlignAttConfig
-    SIMULSTREAMING_AVAILABLE = True
-except ImportError:
-    logger.warning("SimulStreaming dependencies not available for online processor.")
-    SIMULSTREAMING_AVAILABLE = False
-    OnlineProcessorInterface = None
-    torch = None
-
-
 class HypothesisBuffer:
     """
     Buffer to store and process ASR hypothesis tokens.
@@ -118,9 +106,6 @@ class OnlineASRProcessor:
     def __init__(
         self,
         asr,
-        tokenize_method: Optional[callable] = None,
-        buffer_trimming: Tuple[str, float] = ("segment", 15),
-        confidence_validation = False,
         logfile=sys.stderr,
     ):
         """
@@ -131,12 +116,14 @@ class OnlineASRProcessor:
             buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
         """
         self.asr = asr
-        self.tokenize = tokenize_method
+        self.tokenize = asr.tokenizer
         self.logfile = logfile
-        self.confidence_validation = confidence_validation
+        self.confidence_validation = asr.confidence_validation
+        self.global_time_offset = 0.0
         self.init()

-        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
+        self.buffer_trimming_way = asr.buffer_trimming
+        self.buffer_trimming_sec = asr.buffer_trimming_sec

         if self.buffer_trimming_way not in ["sentence", "segment"]:
             raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
@@ -154,6 +141,7 @@ class OnlineASRProcessor:
         self.buffer_time_offset = offset if offset is not None else 0.0
         self.transcript_buffer.last_committed_time = self.buffer_time_offset
         self.committed: List[ASRToken] = []
+        self.time_of_last_asr_output = 0.0

     def get_audio_buffer_end_time(self) -> float:
         """Returns the absolute end time of the current audio_buffer."""
@@ -163,6 +151,21 @@ class OnlineASRProcessor:
         """Append an audio chunk (a numpy array) to the current audio buffer."""
         self.audio_buffer = np.append(self.audio_buffer, audio)

+    def insert_silence(self, silence_duration, offset):
+        """
+        If silences are > 5s, we do a complete context clear. Otherwise, we just insert a small silence and shift the last_attend_frame
+        """
+        # if self.transcript_buffer.buffer:
+        #     self.committed.extend(self.transcript_buffer.buffer)
+        #     self.transcript_buffer.buffer = []
+
+        if True: #silence_duration < 3: #we want the last audio to be treated to not have a gap. could also be handled in the future in ends_with_silence.
+            gap_silence = np.zeros(int(16000 * silence_duration), dtype=np.int16)
+            self.insert_audio_chunk(gap_silence)
+        else:
+            self.init(offset=silence_duration + offset)
+            self.global_time_offset += silence_duration
+
     def prompt(self) -> Tuple[str, str]:
         """
         Returns a tuple: (prompt, context), where:
@@ -210,11 +213,26 @@ class OnlineASRProcessor:
         self.transcript_buffer.insert(tokens, self.buffer_time_offset)
         committed_tokens = self.transcript_buffer.flush()
         self.committed.extend(committed_tokens)
+
+        if committed_tokens:
+            self.time_of_last_asr_output = self.committed[-1].end
+
         completed = self.concatenate_tokens(committed_tokens)
         logger.debug(f">>>> COMPLETE NOW: {completed.text}")
         incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
         logger.debug(f"INCOMPLETE: {incomp.text}")
+
+        buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
+        if not committed_tokens and buffer_duration > self.buffer_trimming_sec:
+            time_since_last_output = self.get_audio_buffer_end_time() - self.time_of_last_asr_output
+            if time_since_last_output > self.buffer_trimming_sec:
+                logger.warning(
+                    f"No ASR output for {time_since_last_output:.2f}s. "
+                    f"Resetting buffer to prevent freezing."
+                )
+                self.init(offset=self.get_audio_buffer_end_time())
+                return [], current_audio_processed_upto
+
         if committed_tokens and self.buffer_trimming_way == "sentence":
             if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
                 self.chunk_completed_sentence()
@@ -226,6 +244,9 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
+        if self.global_time_offset:
+            for token in committed_tokens:
+                token = token.with_offset(self.global_time_offset)
         return committed_tokens, current_audio_processed_upto

     def chunk_completed_sentence(self):
@@ -387,331 +408,3 @@ class OnlineASRProcessor:
|
|||||||
start = None
|
start = None
|
||||||
end = None
|
end = None
|
||||||
return Transcript(start, end, text, probability=probability)
|
return Transcript(start, end, text, probability=probability)
|
||||||
|
|
||||||
|
|
||||||
class VACOnlineASRProcessor:
|
|
||||||
"""
|
|
||||||
Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
|
|
||||||
|
|
||||||
It receives small chunks of audio, applies VAD (e.g. with Silero),
|
|
||||||
and when the system detects a pause in speech (or end of an utterance)
|
|
||||||
it finalizes the utterance immediately.
|
|
||||||
"""
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(self, online_chunk_size: float, *args, **kwargs):
|
|
||||||
self.online_chunk_size = online_chunk_size
|
|
||||||
self.online = OnlineASRProcessor(*args, **kwargs)
|
|
||||||
self.asr = self.online.asr
|
|
||||||
|
|
||||||
# Load a VAD model (e.g. Silero VAD)
|
|
||||||
import torch
|
|
||||||
model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
|
|
||||||
from .silero_vad_iterator import FixedVADIterator
|
|
||||||
|
|
||||||
self.vac = FixedVADIterator(model)
|
|
||||||
self.logfile = self.online.logfile
|
|
||||||
self.last_input_audio_stream_end_time: float = 0.0
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
def init(self):
|
|
||||||
self.online.init()
|
|
||||||
self.vac.reset_states()
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
self.last_input_audio_stream_end_time = self.online.buffer_time_offset
|
|
||||||
self.is_currently_final = False
|
|
||||||
self.status: Optional[str] = None # "voice" or "nonvoice"
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
|
||||||
self.buffer_offset = 0 # in frames
|
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the audio processed by the underlying OnlineASRProcessor."""
|
|
||||||
return self.online.get_audio_buffer_end_time()
|
|
||||||
|
|
||||||
def clear_buffer(self):
|
|
||||||
self.buffer_offset += len(self.audio_buffer)
|
|
||||||
self.audio_buffer = np.array([], dtype=np.float32)
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float):
|
|
||||||
"""
|
|
||||||
Process an incoming small audio chunk:
|
|
||||||
- run VAD on the chunk,
|
|
||||||
- decide whether to send the audio to the online ASR processor immediately,
|
|
||||||
- and/or to mark the current utterance as finished.
|
|
||||||
"""
|
|
||||||
self.last_input_audio_stream_end_time = audio_stream_end_time
|
|
||||||
res = self.vac(audio)
|
|
||||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
|
||||||
|
|
||||||
if res is not None:
|
|
||||||
# VAD returned a result; adjust the frame number
|
|
||||||
frame = list(res.values())[0] - self.buffer_offset
|
|
||||||
if "start" in res and "end" not in res:
|
|
||||||
self.status = "voice"
|
|
||||||
send_audio = self.audio_buffer[frame:]
|
|
||||||
self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.clear_buffer()
|
|
||||||
elif "end" in res and "start" not in res:
|
|
||||||
self.status = "nonvoice"
|
|
||||||
send_audio = self.audio_buffer[:frame]
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.is_currently_final = True
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
beg = res["start"] - self.buffer_offset
|
|
||||||
end = res["end"] - self.buffer_offset
|
|
||||||
self.status = "nonvoice"
|
|
||||||
send_audio = self.audio_buffer[beg:end]
|
|
||||||
self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
|
|
||||||
self.online.insert_audio_chunk(send_audio)
|
|
||||||
self.current_online_chunk_buffer_size += len(send_audio)
|
|
||||||
self.is_currently_final = True
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
if self.status == "voice":
|
|
||||||
self.online.insert_audio_chunk(self.audio_buffer)
|
|
||||||
self.current_online_chunk_buffer_size += len(self.audio_buffer)
|
|
||||||
self.clear_buffer()
|
|
||||||
else:
|
|
||||||
# Keep 1 second worth of audio in case VAD later detects voice,
|
|
||||||
# but trim to avoid unbounded memory usage.
|
|
||||||
self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
|
|
||||||
self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
|
|
||||||
|
|
||||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Depending on the VAD status and the amount of accumulated audio,
|
|
||||||
process the current audio chunk.
|
|
||||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
|
||||||
"""
|
|
||||||
if self.is_currently_final:
|
|
||||||
return self.finish()
|
|
||||||
elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
return self.online.process_iter()
|
|
||||||
else:
|
|
||||||
logger.debug("No online update, only VAD")
|
|
||||||
return [], self.last_input_audio_stream_end_time
|
|
||||||
|
|
||||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
|
||||||
"""
|
|
||||||
Finish processing by flushing any remaining text.
|
|
||||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
|
||||||
"""
|
|
||||||
result_tokens, processed_upto = self.online.finish()
|
|
||||||
self.current_online_chunk_buffer_size = 0
|
|
||||||
self.is_currently_final = False
|
|
||||||
return result_tokens, processed_upto
|
|
||||||
|
|
||||||
def get_buffer(self):
|
|
||||||
"""
|
|
||||||
Get the unvalidated buffer in string format.
|
|
||||||
"""
|
|
||||||
return self.online.concatenate_tokens(self.online.transcript_buffer.buffer)
|
|
||||||
|
|
||||||
|
|
||||||
class SimulStreamingOnlineProcessor:
|
|
||||||
SAMPLING_RATE = 16000
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
asr,
|
|
||||||
tokenize_method: Optional[callable] = None,
|
|
||||||
buffer_trimming: Tuple[str, float] = ("segment", 15),
|
|
||||||
confidence_validation = False,
|
|
||||||
logfile=sys.stderr,
|
|
||||||
):
|
|
||||||
if not SIMULSTREAMING_AVAILABLE:
|
|
||||||
raise ImportError("SimulStreaming dependencies are not available.")
|
|
||||||
|
|
||||||
self.asr = asr
|
|
||||||
self.tokenize = tokenize_method
|
|
||||||
self.logfile = logfile
|
|
||||||
self.confidence_validation = confidence_validation
|
|
||||||
self.init()
|
|
||||||
|
|
||||||
# buffer does not work yet
|
|
||||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
|
||||||
|
|
||||||
def init(self, offset: Optional[float] = None):
|
|
||||||
"""Initialize or reset the processing state."""
|
|
||||||
self.audio_chunks = []
|
|
||||||
self.offset = offset if offset is not None else 0.0
|
|
||||||
self.is_last = False
|
|
||||||
self.beg = self.offset
|
|
||||||
self.end = self.offset
|
|
||||||
self.cumulative_audio_duration = 0.0
|
|
||||||
self.last_audio_stream_end_time = self.offset
|
|
||||||
|
|
||||||
self.committed: List[ASRToken] = []
|
|
||||||
self.last_result_tokens: List[ASRToken] = []
|
|
||||||
self.buffer_content = ""
|
|
||||||
self.processed_audio_duration = 0.0
|
|
||||||
|
|
||||||
def get_audio_buffer_end_time(self) -> float:
|
|
||||||
"""Returns the absolute end time of the current audio buffer."""
|
|
||||||
return self.end
|
|
||||||
|
|
||||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
|
||||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
|
||||||
if torch is None:
|
|
||||||
raise ImportError("PyTorch is required for SimulStreaming but not available")
|
|
||||||
|
|
||||||
# Convert numpy array to torch tensor
|
|
||||||
audio_tensor = torch.from_numpy(audio).float()
|
|
||||||
self.audio_chunks.append(audio_tensor)
|
|
||||||
|
|
||||||
# Update timing
|
|
||||||
chunk_duration = len(audio) / self.SAMPLING_RATE
|
|
||||||
self.cumulative_audio_duration += chunk_duration
|
|
||||||
|
|
||||||
if audio_stream_end_time is not None:
|
|
||||||
self.last_audio_stream_end_time = audio_stream_end_time
|
|
||||||
self.end = audio_stream_end_time
|
|
||||||
else:
|
|
||||||
self.end = self.offset + self.cumulative_audio_duration
|
|
||||||
|
|
||||||
def prompt(self) -> Tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Returns a tuple: (prompt, context).
|
|
||||||
SimulStreaming handles prompting internally, so we return empty strings.
|
|
||||||
"""
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
def get_buffer(self):
|
|
||||||
"""
|
|
||||||
Get the unvalidated buffer content.
|
|
||||||
"""
|
|
||||||
buffer_end = self.end if hasattr(self, 'end') else None
|
|
||||||
return Transcript(
|
|
||||||
start=None,
|
|
||||||
end=buffer_end,
|
|
||||||
text=self.buffer_content,
|
|
||||||
probability=None
|
|
||||||
)
|
|
||||||
|
|
||||||
    def process_iter(self) -> Tuple[List[ASRToken], float]:
        """
        Process accumulated audio chunks using SimulStreaming.

        Returns a tuple: (list of newly committed ASRToken objects, time up to which audio has been processed).
        """
        if not self.audio_chunks:
            return [], self.end

        try:
            # Concatenate all pending audio chunks.
            if len(self.audio_chunks) == 1:
                audio = self.audio_chunks[0]
            else:
                audio = torch.cat(self.audio_chunks, dim=0)

            audio_duration = audio.shape[0] / self.SAMPLING_RATE if audio.shape[0] > 0 else 0
            self.processed_audio_duration += audio_duration

            self.audio_chunks = []

            logger.debug(f"SimulStreaming processing audio shape: {audio.shape}, duration: {audio_duration:.2f}s")
            logger.debug(f"Current end time: {self.end:.2f}s, last stream time: {self.last_audio_stream_end_time:.2f}s")

            result = self.asr.model.infer(audio, is_last=self.is_last)

            if torch.is_tensor(result):
                # Filter out padding tokens, as is done in SimulWhisper.
                from simul_whisper.simul_whisper import DEC_PAD
                result = result[result < DEC_PAD]

            # Adapted from simul_whisper/simul_whisper.py.
            if len(result) > 0:
                decoded_text = self.asr.model.tokenizer.decode(result.cpu().numpy())
                logger.debug(f"SimulStreaming decoded: {decoded_text}")

                if decoded_text.strip():
                    words = decoded_text.strip().split()
                    new_tokens = []

                    num_words = len(words)
                    if num_words > 0:
                        # Distribute words evenly across the processed audio duration.
                        # Diarization needs per-word timestamps, even if these are only approximate.
                        start_time = self.end - audio_duration
                        time_per_word = audio_duration / num_words if num_words > 1 else audio_duration

                        for i, word in enumerate(words):
                            token_start = start_time + (i * time_per_word)
                            token_end = start_time + ((i + 1) * time_per_word)

                            token_end = min(token_end, self.end)

                            token = ASRToken(
                                start=token_start,
                                end=token_end,
                                text=word,
                                probability=0.95  # placeholder probability; could perhaps be extracted from the model
                            )
                            new_tokens.append(token)

                    self.beg = self.end

                    self.committed.extend(new_tokens)
                    self.last_result_tokens = new_tokens

                    logger.debug(f"SimulStreaming generated {len(new_tokens)} tokens with end time: {self.end:.2f}s")
                    return new_tokens, self.end

            return [], self.end

        except Exception as e:
            logger.error(f"SimulStreaming processing error: {e}")
            logger.error(f"Error details: {type(e).__name__}: {str(e)}")
            return [], self.end

    def finish(self) -> Tuple[List[ASRToken], float]:
        logger.debug("SimulStreaming finish() called")
        self.is_last = True
        final_tokens, final_time = self.process_iter()
        self.is_last = False
        return final_tokens, final_time

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        """Concatenate tokens into a Transcript object."""
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        # Note: the average is taken over all tokens, including those without a probability.
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)

    def chunk_at(self, time: float):
        """
        No-op: buffer chunking is handled internally by SimulStreaming; kept for interface compatibility.
        """
        logger.debug(f"SimulStreaming chunk_at({time:.2f}) - handled internally")
        pass

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Wrap all tokens into a single Sentence (no real sentence segmentation).
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)
        sentence = Sentence(
            start=tokens[0].start,
            end=tokens[-1].end,
            text=full_text
        )
        return [sentence]
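For reference, the even word-timing distribution used in process_iter() above can be shown in isolation. The sketch below is not part of the diff; the helper name and the example values are hypothetical.

# Standalone sketch (not from the diff): how process_iter() spreads word
# timestamps evenly across the chunk it just decoded. The helper name and
# example values are hypothetical.
def distribute_word_times(words, chunk_start, chunk_duration):
    num_words = len(words)
    if num_words == 0:
        return []
    time_per_word = chunk_duration / num_words
    return [
        (word, chunk_start + i * time_per_word, chunk_start + (i + 1) * time_per_word)
        for i, word in enumerate(words)
    ]

# Three words decoded from a 1.5 s chunk that started at t = 10.0 s:
# [('hello', 10.0, 10.5), ('there', 10.5, 11.0), ('world', 11.0, 11.5)]
print(distribute_word_times(["hello", "there", "world"], 10.0, 1.5))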
@@ -1,163 +0,0 @@
import torch


# This is copied from silero-vad's vad_utils.py:
# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
# (except changed defaults)

# Their licence is MIT, same as ours: https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/LICENSE


class VADIterator:
    def __init__(
        self,
        model,
        threshold: float = 0.5,
        sampling_rate: int = 16000,
        min_silence_duration_ms: int = 500,  # makes sense on one recording that I checked
        speech_pad_ms: int = 100,  # same
    ):
        """
        Class for stream imitation

        Parameters
        ----------
        model: preloaded .jit silero VAD model

        threshold: float (default - 0.5)
            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
            It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

        sampling_rate: int (default - 16000)
            Currently silero VAD models support 8000 and 16000 sample rates

        min_silence_duration_ms: int (default - 500 milliseconds)
            At the end of each speech chunk, wait for min_silence_duration_ms before separating it

        speech_pad_ms: int (default - 100 milliseconds)
            Final speech chunks are padded by speech_pad_ms on each side
        """

        self.model = model
        self.threshold = threshold
        self.sampling_rate = sampling_rate

        if sampling_rate not in [8000, 16000]:
            raise ValueError(
                "VADIterator does not support sampling rates other than [8000, 16000]"
            )

        # e.g. with the defaults at 16 kHz: 16000 * 500 / 1000 = 8000 silence samples,
        # 16000 * 100 / 1000 = 1600 padding samples
        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        self.reset_states()

    def reset_states(self):

        self.model.reset_states()
        self.triggered = False
        self.temp_end = 0
        self.current_sample = 0

    def __call__(self, x, return_seconds=False):
        """
        x: torch.Tensor
            audio chunk (see examples in repo)

        return_seconds: bool (default - False)
            whether to return timestamps in seconds (default - samples)
        """

        if not torch.is_tensor(x):
            try:
                x = torch.Tensor(x)
            except Exception:
                raise TypeError("Audio cannot be cast to tensor. Cast it manually")

        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
        self.current_sample += window_size_samples

        speech_prob = self.model(x, self.sampling_rate).item()

        if (speech_prob >= self.threshold) and self.temp_end:
            self.temp_end = 0

        if (speech_prob >= self.threshold) and not self.triggered:
            self.triggered = True
            speech_start = self.current_sample - self.speech_pad_samples
            return {
                "start": (
                    int(speech_start)
                    if not return_seconds
                    else round(speech_start / self.sampling_rate, 1)
                )
            }

        if (speech_prob < self.threshold - 0.15) and self.triggered:
            if not self.temp_end:
                self.temp_end = self.current_sample
            if self.current_sample - self.temp_end < self.min_silence_samples:
                return None
            else:
                speech_end = self.temp_end + self.speech_pad_samples
                self.temp_end = 0
                self.triggered = False
                return {
                    "end": (
                        int(speech_end)
                        if not return_seconds
                        else round(speech_end / self.sampling_rate, 1)
                    )
                }

        return None


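As a usage illustration (not part of the diff), the plain VADIterator above expects consecutive 512-sample chunks at 16 kHz and emits a dict with "start" when speech begins and "end" once enough trailing silence has accumulated. The model loading mirrors the __main__ demo further below; the audio here is synthetic silence.

# Hedged usage sketch, not from the diff: drive VADIterator with exactly
# 512-sample chunks (the size current Silero models expect at 16 kHz).
import numpy as np
import torch

model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
vad = VADIterator(model, threshold=0.5, sampling_rate=16000)

audio = np.zeros(16000, dtype=np.float32)  # one second of synthetic silence
for i in range(0, len(audio) - 511, 512):
    event = vad(audio[i:i + 512], return_seconds=True)
    if event is not None:
        print(event)  # e.g. {'start': 0.3} or {'end': 1.1} on real speech
vad.reset_states()  # reset before starting a new stream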
#######################
# because Silero now requires exactly 512-sized audio chunks

import numpy as np


class FixedVADIterator(VADIterator):
    """It fixes VADIterator by allowing it to process audio of any length, not only exactly 512 samples at once.
    If the audio to be processed at once is long and multiple voiced segments are detected,
    then __call__ returns the start of the first segment and the end of the last segment (or its middle, which means no end was detected).
    """

    def reset_states(self):
        super().reset_states()
        self.buffer = np.array([], dtype=np.float32)

    def __call__(self, x, return_seconds=False):
        self.buffer = np.append(self.buffer, x)
        ret = None
        while len(self.buffer) >= 512:
            r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
            self.buffer = self.buffer[512:]
            if ret is None:
                ret = r
            elif r is not None:
                if "end" in r:
                    ret["end"] = r["end"]  # the later end
                if "start" in r and "end" in ret:  # there is an earlier start;
                    # remove end, merging this segment with the previous one
                    del ret["end"]
        return ret if ret != {} else None


if __name__ == "__main__":
    # test/demonstrate the need for FixedVADIterator:

    import torch

    model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
    vac = FixedVADIterator(model)
    # vac = VADIterator(model)  # the second case crashes with this

    # this works for both:
    audio_buffer = np.array([0] * (512), dtype=np.float32)
    vac(audio_buffer)

    # this crashes on the non-fixed VADIterator with
    # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
    audio_buffer = np.array([0] * (512 - 1), dtype=np.float32)
    vac(audio_buffer)
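Complementing the demo above (again not part of the diff), FixedVADIterator accepts chunks of arbitrary length, buffering internally to 512-sample windows and merging the events it finds:

# Hedged sketch, not from the diff: FixedVADIterator tolerates uneven chunk
# sizes and merges start/end events found across its internal 512-sample windows.
import numpy as np
import torch

model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
vac = FixedVADIterator(model)

for chunk_len in (1000, 333, 4096):  # arbitrary, non-512 chunk sizes
    chunk = np.zeros(chunk_len, dtype=np.float32)  # synthetic silence as a stand-in
    event = vac(chunk, return_seconds=True)
    if event is not None:
        # On real speech this may contain "start", "end", or both when a whole
        # voiced segment fell inside one call.
        print(event)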
@@ -5,8 +5,8 @@ import librosa
 from functools import lru_cache
 import time
 import logging
-from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR, SimulStreamingASR, SIMULSTREAMING_AVAILABLE
-from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor, SimulStreamingOnlineProcessor, SIMULSTREAMING_AVAILABLE as SIMULSTREAMING_ONLINE_AVAILABLE
+from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
+from whisperlivekit.warmup import warmup_asr
 
 logger = logging.getLogger(__name__)
 
@@ -64,42 +64,23 @@ def create_tokenizer(lan):
     return WtPtok()
 
 
-def backend_factory(args):
-    backend = args.backend
+def backend_factory(
+    backend,
+    lan,
+    model_size,
+    model_cache_dir,
+    model_dir,
+    task,
+    buffer_trimming,
+    buffer_trimming_sec,
+    confidence_validation,
+    warmup_file=None,
+    min_chunk_size=None,
+):
+    backend = backend
     if backend == "openai-api":
         logger.debug("Using OpenAI API.")
-        asr = OpenaiApiASR(lan=args.lan)
+        asr = OpenaiApiASR(lan=lan)
-    elif backend == "simulstreaming":
-        logger.debug("Using SimulStreaming backend.")
-        if not SIMULSTREAMING_AVAILABLE:
-            raise ImportError(
-                "SimulStreaming backend is not available. Please install SimulStreaming dependencies. "
-                "See the documentation for installation instructions."
-            )
-
-        simulstreaming_kwargs = {}
-        for attr in ['frame_threshold', 'beams', 'decoder_type', 'audio_max_len', 'audio_min_len',
-                     'cif_ckpt_path', 'never_fire', 'init_prompt', 'static_init_prompt',
-                     'max_context_tokens', 'model_path']:
-            if hasattr(args, attr):
-                simulstreaming_kwargs[attr] = getattr(args, attr)
-
-        # Add segment_length from min_chunk_size
-        simulstreaming_kwargs['segment_length'] = getattr(args, 'min_chunk_size', 0.5)
-        simulstreaming_kwargs['task'] = args.task
-
-        size = args.model
-        t = time.time()
-        logger.info(f"Loading SimulStreaming {size} model for language {args.lan}...")
-        asr = SimulStreamingASR(
-            modelsize=size,
-            lan=args.lan,
-            cache_dir=getattr(args, 'model_cache_dir', None),
-            model_dir=getattr(args, 'model_dir', None),
-            **simulstreaming_kwargs
-        )
-        e = time.time()
-        logger.info(f"done. It took {round(e-t,2)} seconds.")
     else:
         if backend == "faster-whisper":
             asr_cls = FasterWhisperASR
@@ -109,137 +90,33 @@ def backend_factory(args):
             asr_cls = WhisperTimestampedASR
 
         # Only for FasterWhisperASR and WhisperTimestampedASR
-        size = args.model
         t = time.time()
-        logger.info(f"Loading Whisper {size} model for language {args.lan}...")
+        logger.info(f"Loading Whisper {model_size} model for language {lan}...")
         asr = asr_cls(
-            modelsize=size,
+            model_size=model_size,
-            lan=args.lan,
+            lan=lan,
-            cache_dir=getattr(args, 'model_cache_dir', None),
+            cache_dir=model_cache_dir,
-            model_dir=getattr(args, 'model_dir', None),
+            model_dir=model_dir,
         )
         e = time.time()
         logger.info(f"done. It took {round(e-t,2)} seconds.")
 
-    # Apply common configurations
-    if getattr(args, "vad", False): # Checks if VAD argument is present and True
-        logger.info("Setting VAD filter")
-        asr.use_vad()
-
-    language = args.lan
-    if args.task == "translate":
-        if backend != "simulstreaming":
-            asr.set_translate_task()
+    if task == "translate":
         tgt_language = "en"  # Whisper translates into English
     else:
-        tgt_language = language  # Whisper transcribes in this language
+        tgt_language = lan  # Whisper transcribes in this language
 
     # Create the tokenizer
-    if args.buffer_trimming == "sentence":
+    if buffer_trimming == "sentence":
         tokenizer = create_tokenizer(tgt_language)
     else:
         tokenizer = None
-    return asr, tokenizer
-
-
-def online_factory(args, asr, tokenizer, logfile=sys.stderr):
-    if args.backend == "simulstreaming":
-        if not SIMULSTREAMING_ONLINE_AVAILABLE:
-            raise ImportError("SimulStreaming online processor is not available.")
-
-        logger.debug("Creating SimulStreaming online processor")
-        online = SimulStreamingOnlineProcessor(
-            asr,
-            tokenizer,
-            logfile=logfile,
-            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
-            confidence_validation=args.confidence_validation
-        )
-    elif args.vac:
-        online = VACOnlineASRProcessor(
-            args.min_chunk_size,
-            asr,
-            tokenizer,
-            logfile=logfile,
-            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
-            confidence_validation = args.confidence_validation
-        )
-    else:
-        online = OnlineASRProcessor(
-            asr,
-            tokenizer,
-            logfile=logfile,
-            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
-            confidence_validation = args.confidence_validation
-        )
-    return online
-
-
-def asr_factory(args, logfile=sys.stderr):
-    """
-    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
-    """
-    asr, tokenizer = backend_factory(args)
-    online = online_factory(args, asr, tokenizer, logfile=logfile)
-    return asr, online
-
-
-def warmup_asr(asr, warmup_file=None, timeout=5):
-    """
-    Warmup the ASR model by transcribing a short audio file.
-    """
-    import os
-    import tempfile
-
-    is_simulstreaming = hasattr(asr, 'warmup') and callable(getattr(asr, 'warmup'))
-
-    if warmup_file is None:
-        # Download JFK sample if not already present
-        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
-        temp_dir = tempfile.gettempdir()
-        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
-
-        if not os.path.exists(warmup_file):
-            logger.debug(f"Downloading warmup file from {jfk_url}")
-            print(f"Downloading warmup file from {jfk_url}")
-            import time
-            import urllib.request
-            import urllib.error
-            import socket
-
-            original_timeout = socket.getdefaulttimeout()
-            socket.setdefaulttimeout(timeout)
-
-            start_time = time.time()
-            try:
-                urllib.request.urlretrieve(jfk_url, warmup_file)
-                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
-            except (urllib.error.URLError, socket.timeout) as e:
-                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
-                return False
-            finally:
-                socket.setdefaulttimeout(original_timeout)
-    elif not warmup_file:
-        return False
-
-    if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
-        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
-        return False
-
-    print(f"Warming up {'SimulStreaming' if is_simulstreaming else 'Whisper'} with {warmup_file}")
-    try:
-        import librosa
-        audio, sr = librosa.load(warmup_file, sr=16000)
-    except Exception as e:
-        logger.warning(f"Failed to load audio file: {e}")
-        return False
-
-    try:
-        if is_simulstreaming:
-            asr.warmup(audio)
-        else:
-            asr.transcribe(audio)
-
-        logger.info(f"{'SimulStreaming' if is_simulstreaming else 'Whisper'} is warmed up")
-        return True
-
-    except Exception as e:
-        logger.warning(f"Warmup failed: {e}")
-        return False
+    warmup_asr(asr, warmup_file)
+
+    asr.confidence_validation = confidence_validation
+    asr.tokenizer = tokenizer
+    asr.buffer_trimming = buffer_trimming
+    asr.buffer_trimming_sec = buffer_trimming_sec
+    return asr
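To make the refactor concrete: the new backend_factory in this diff takes explicit parameters instead of an argparse namespace, warms up the model, attaches the tokenizer and trimming settings to the ASR object, and returns it. A hedged call-site sketch follows; the concrete values are illustrative and not taken from the diff.

# Hedged sketch of calling the refactored backend_factory from this diff
# (assuming it is imported from the module being changed above).
# Parameter names match the new signature; the values are examples only.
asr = backend_factory(
    backend="faster-whisper",
    lan="en",
    model_size="large-v3",          # illustrative model size
    model_cache_dir=None,
    model_dir=None,
    task="transcribe",
    buffer_trimming="sentence",     # "sentence" enables the sentence tokenizer
    buffer_trimming_sec=15,
    confidence_validation=False,
    warmup_file=None,
    min_chunk_size=0.5,
)
# Per the new code in the hunk above, the returned asr carries tokenizer,
# buffer_trimming and buffer_trimming_sec as attributes.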