# Compare commits

537 commits: `regularfry` ... `rework_sta`
**.gitignore** (vendored, 30 changes)

@@ -55,22 +55,6 @@ coverage.xml
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

@@ -127,3 +111,17 @@ dmypy.json

# Pyre type checker
.pyre/

*.wav
run_*.sh

# Downloaded models
*.pt

# Debug & testing
test_*.py
launch.json
.DS_Store
test/*
nllb-200-distilled-600M-ctranslate2/*
*.mp3
**CONTRIBUTING.md** (new file, 46 lines)

@@ -0,0 +1,46 @@

# Contributing

Thank you for considering contributing! We appreciate your time and effort to help make this project better.

## Before You Start

1. **Search for Existing Issues or Discussions:**
   - Before opening a new issue or discussion, please check if there's already an existing one related to your topic. This helps avoid duplicates and keeps discussions centralized.

2. **Discuss Your Contribution:**
   - If you plan to make a significant change, it's advisable to discuss it in an issue first. This ensures that your contribution aligns with the project's goals and avoids duplicated effort.

3. **General questions about whisper streaming web:**
   - For general questions about whisper streaming web, use the discussion space on GitHub. This helps foster a collaborative environment and encourages knowledge-sharing.

## Opening Issues

If you encounter a problem with WhisperLiveKit or want to suggest an improvement, please follow these guidelines when opening an issue:

- **Bug Reports:**
  - Clearly describe the error. **Please indicate the parameters you use, especially the model(s).**
  - Provide a minimal, reproducible example that demonstrates the issue.

- **Feature Requests:**
  - Clearly outline the new feature you are proposing.
  - Explain how it would benefit the project.

## Opening Pull Requests

We welcome and appreciate contributions! To ensure a smooth review process, please follow these guidelines when opening a pull request:

- **Commit Messages:**
  - Write clear and concise commit messages, explaining the purpose of each change.

- **Documentation:**
  - Update documentation when introducing new features or making changes that impact existing functionality.

- **Tests:**
  - If applicable, add or update tests to cover your changes.

- **Discuss Before Major Changes:**
  - If your PR includes significant changes, discuss it in an issue first.

## Thank You

Your contributions make WhisperLiveKit better for everyone. Thank you for your time and dedication!
**DEV_NOTES.md** (new file, 91 lines)

@@ -0,0 +1,91 @@

# 1. Simulstreaming: Decouple the encoder for faster inference

Simulstreaming encoder-time experiments (whisperlivekit/simul_whisper/simul_whisper.py, l. 397):

On macOS, Apple Silicon M4:

| Encoder | base.en | small |
|--------|---------|-------|
| WHISPER (no modification) | 0.35s | 1.09s |
| FASTER_WHISPER | 0.4s | 1.20s |
| MLX_WHISPER | 0.07s | 0.20s |

Memory saved by loading only the encoder from the optimized framework. For tiny.en, MLX Whisper weight sizes:

- Decoder weights: 59110771 bytes
- Encoder weights: 15268874 bytes
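The timings above are encoder-only forward passes. Below is a minimal sketch of how such a measurement can be taken with the reference `openai-whisper` package (the WHISPER row); the model name and the 30-second silent input are placeholder assumptions, not the repo's actual benchmark script:

```python
# Hedged sketch: time a single Whisper encoder forward pass on a 30 s window.
import time
import numpy as np
import torch
import whisper

model = whisper.load_model("base.en")
audio = whisper.pad_or_trim(np.zeros(16000 * 30, dtype=np.float32))  # placeholder audio
mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)

with torch.no_grad():
    model.encoder(mel.unsqueeze(0))                    # warm-up pass
    start = time.perf_counter()
    features = model.encoder(mel.unsqueeze(0))         # encoder only, no decoding
print(f"encoder forward: {time.perf_counter() - start:.2f}s, output shape {tuple(features.shape)}")
```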
# 2. Translation: Faster model for each system

## Benchmark Results

Testing on a MacBook M3 with the NLLB-200-distilled-600M model:

### Standard Transformers vs CTranslate2

| Test Text | Standard Inference Time | CTranslate2 Inference Time | Speedup |
|-----------|-------------------------|---------------------------|---------|
| UN Chief says there is no military solution in Syria | 0.9395s | 2.0472s | 0.5x |
| The rapid advancement of AI technology is transforming various industries | 0.7171s | 1.7516s | 0.4x |
| Climate change poses a significant threat to global ecosystems | 0.8533s | 1.8323s | 0.5x |
| International cooperation is essential for addressing global challenges | 0.7209s | 1.3575s | 0.5x |
| The development of renewable energy sources is crucial for a sustainable future | 0.8760s | 1.5589s | 0.6x |

**Results:**

- Total Standard time: 4.1068s
- Total CTranslate2 time: 8.5476s
- CTranslate2 is slower on this system, so use Transformers; ideally we would have an MLX implementation.
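The comparison pits the Hugging Face Transformers implementation against CTranslate2 on the same sentences. A minimal, hedged sketch of how such a comparison can be reproduced follows; the model IDs, the converted-model directory name, and the language pair are assumptions:

```python
# Hedged sketch: time one NLLB translation with Transformers vs. CTranslate2.
import time
import ctranslate2
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

SRC, TGT = "eng_Latn", "fra_Latn"   # assumed language pair
text = "UN Chief says there is no military solution in Syria"

# Standard Transformers inference
tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang=SRC)
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
start = time.perf_counter()
ids = model.generate(**tok(text, return_tensors="pt"),
                     forced_bos_token_id=tok.convert_tokens_to_ids(TGT))
print("transformers:", round(time.perf_counter() - start, 4),
      tok.batch_decode(ids, skip_special_tokens=True)[0])

# CTranslate2 inference (assumes a pre-converted model directory)
translator = ctranslate2.Translator("nllb-200-distilled-600M-ctranslate2", device="cpu")
source_tokens = tok.convert_ids_to_tokens(tok.encode(text))
start = time.perf_counter()
result = translator.translate_batch([source_tokens], target_prefix=[[TGT]])
target_tokens = result[0].hypotheses[0][1:]            # drop the target-language prefix token
print("ctranslate2:", round(time.perf_counter() - start, 4),
      tok.decode(tok.convert_tokens_to_ids(target_tokens), skip_special_tokens=True))
```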
# 3. SortFormer Diarization: 4-to-2 Speaker Constraint Algorithm

Transform a diarization model that predicts up to 4 speakers into one that predicts up to 2 speakers by mapping the output predictions.

## Problem Statement

- Input: `self.total_preds` with shape `(x, x, 4)` - predictions for 4 speakers
- Output: Constrained predictions with shape `(x, x, 2)` - predictions for 2 speakers

### Initial Setup

For each time step `i`, we have a ranking of 4 speaker predictions (1-4). When only 2 speakers are present, the model will have close predictions for the 2 active speaker positions.

Instead of `np.argmax(preds_np, axis=1)`, we take the top 2 predictions and build a dynamic 4→2 mapping that can evolve over time (a concrete sketch follows at the end of this section).

### Algorithm

```python
top_2_speakers = np.argsort(preds_np, axis=1)[:, -2:]
```

- `DS_a_{i}`: Top detected speaker for prediction i
- `DS_b_{i}`: Second detected speaker for prediction i
- `AS_{i}`: Attributed speaker for prediction i
- `GTS_A`: Ground truth speaker A
- `GTS_B`: Ground truth speaker B
- `DIST(a, b)`: Distance between detected speakers a and b

3. **Attribution Logic**

```
AS_0 ← A
AS_1 ← B

IF DIST(DS_a_0, DS_a_1) < DIST(DS_a_0, DS_a_2) AND
   DIST(DS_a_0, DS_a_1) < DIST(DS_a_1, DS_a_2):
    # Likely that DS_a_0 = DS_a_1 (same speaker)
    AS_1 ← A
    AS_2 ← B

ELIF DIST(DS_a_0, DS_a_2) < DIST(DS_a_0, DS_a_1) AND
     DIST(DS_a_0, DS_a_2) < DIST(DS_a_1, DS_a_2):
    AS_2 ← A

ELSE:
    AS_2 ← B

to finish
```
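The attribution logic above is still marked "to finish". The following is a minimal, hypothetical rendering of the constraint itself (top-2 selection plus a fixed first-seen slot assignment); the dynamic distance-based re-mapping is deliberately left out:

```python
# Hedged sketch: constrain 4-speaker frame predictions to 2 output slots.
import numpy as np

def constrain_4_to_2(preds_np: np.ndarray) -> np.ndarray:
    """preds_np: (frames, 4) activity scores -> (frames, 2) constrained scores.
    A (x, x, 4) tensor can be reshaped to (-1, 4) before calling this."""
    out = np.zeros((preds_np.shape[0], 2), dtype=preds_np.dtype)
    top2 = np.argsort(preds_np, axis=1)[:, -2:][:, ::-1]   # top-2 channels, best first
    mapping = {}                                            # 4-way channel -> slot 0 (A) or 1 (B)
    for i, channels in enumerate(top2):
        for ch in channels:
            if ch not in mapping and len(mapping) < 2:
                mapping[ch] = len(mapping)                  # first channel seen -> A, second -> B
            if ch in mapping:
                out[i, mapping[ch]] = preds_np[i, ch]
    return out

preds = np.random.rand(20, 4)
print(constrain_4_to_2(preds).shape)   # (20, 2)
```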
**Dockerfile** (new file, 83 lines)

@@ -0,0 +1,83 @@

FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

WORKDIR /app

ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        python3 \
        python3-pip \
        python3-venv \
        ffmpeg \
        git \
        build-essential \
        python3-dev \
        ca-certificates && \
    rm -rf /var/lib/apt/lists/*

RUN python3 -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# timeout/retries for large torch wheels
RUN pip3 install --upgrade pip setuptools wheel && \
    pip3 --disable-pip-version-check install --timeout=120 --retries=5 \
        --index-url https://download.pytorch.org/whl/cu129 \
        torch torchaudio \
    || (echo "Initial install failed — retrying with extended timeout..." && \
        pip3 --disable-pip-version-check install --timeout=300 --retries=3 \
        --index-url https://download.pytorch.org/whl/cu129 \
        torch torchvision torchaudio)

COPY . .

# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
        echo "Installing with extras: [$EXTRAS]"; \
        pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
    else \
        echo "Installing base package only"; \
        pip install --no-cache-dir whisperlivekit; \
    fi

# In-container caching for Hugging Face models, by either:
# A) Making the cache directory persistent via an anonymous volume.
#    Note: this only persists for a single, named container and is
#    only for convenience at the dev/test stage.
#    For prod, it is better to use a named volume via host mount/k8s.
VOLUME ["/root/.cache/huggingface/hub"]

# or
# B) Conditionally copying a local pre-cache from the build context to the
#    container's cache via the HF_PRECACHE_DIR build-arg.
#    WARNING: this will copy ALL files in the pre-cache location.

# Conditionally copy a cache directory if provided
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
        echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
        mkdir -p /root/.cache/huggingface/hub && \
        cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
    else \
        echo "No local Hugging Face cache specified, skipping copy"; \
    fi

# Conditionally copy a Hugging Face token if provided. Useful for the Diart backend (pyannote audio models)
RUN if [ -n "$HF_TKN_FILE" ]; then \
        echo "Copying Hugging Face token from $HF_TKN_FILE"; \
        mkdir -p /root/.cache/huggingface && \
        cp $HF_TKN_FILE /root/.cache/huggingface/token; \
    else \
        echo "No Hugging Face token file specified, skipping token setup"; \
    fi

EXPOSE 8000

ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

CMD ["--model", "medium"]
**Dockerfile.cpu** (new file, 61 lines)

@@ -0,0 +1,61 @@

FROM python:3.13-slim

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

WORKDIR /app

ARG EXTRAS
ARG HF_PRECACHE_DIR
ARG HF_TKN_FILE

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        ffmpeg \
        git \
        build-essential \
        python3-dev && \
    rm -rf /var/lib/apt/lists/*

# Install CPU-only PyTorch
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

COPY . .

# Install WhisperLiveKit directly, allowing for optional dependencies
RUN if [ -n "$EXTRAS" ]; then \
        echo "Installing with extras: [$EXTRAS]"; \
        pip install --no-cache-dir whisperlivekit[$EXTRAS]; \
    else \
        echo "Installing base package only"; \
        pip install --no-cache-dir whisperlivekit; \
    fi

# Enable in-container caching for Hugging Face models
VOLUME ["/root/.cache/huggingface/hub"]

# Conditionally copy a local pre-cache from the build context
RUN if [ -n "$HF_PRECACHE_DIR" ]; then \
        echo "Copying Hugging Face cache from $HF_PRECACHE_DIR"; \
        mkdir -p /root/.cache/huggingface/hub && \
        cp -r $HF_PRECACHE_DIR/* /root/.cache/huggingface/hub; \
    else \
        echo "No local Hugging Face cache specified, skipping copy"; \
    fi

# Conditionally copy a Hugging Face token if provided
RUN if [ -n "$HF_TKN_FILE" ]; then \
        echo "Copying Hugging Face token from $HF_TKN_FILE"; \
        mkdir -p /root/.cache/huggingface && \
        cp $HF_TKN_FILE /root/.cache/huggingface/token; \
    else \
        echo "No Hugging Face token file specified, skipping token setup"; \
    fi

# Expose port for the transcription server
EXPOSE 8000

ENTRYPOINT ["whisperlivekit-server", "--host", "0.0.0.0"]

# Default args - you might want to use a smaller model for CPU
CMD ["--model", "tiny"]
**LICENSE** (223 changes)

@@ -1,21 +1,210 @@

MIT License

Copyright (c) 2023 ÚFAL

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2025 Quentin Fuxa

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---

## Based on:

- **SimulWhisper** by the Speech and Audio Technology Lab of Tsinghua University – Apache-2.0 – https://github.com/backspacetg/simul_whisper
- **SimulStreaming** by ÚFAL – MIT License – https://github.com/ufal/SimulStreaming
- **NeMo** by NVIDIA – Apache-2.0 – https://github.com/NVIDIA-NeMo/NeMo
- **whisper_streaming** by ÚFAL – MIT License – https://github.com/ufal/whisper_streaming
- **silero-vad** by Snakers4 – MIT License – https://github.com/snakers4/silero-vad
- **Diart** by juanmc2005 – MIT License – https://github.com/juanmc2005/diart
**README.md** (452 changes)

@@ -1,247 +1,269 @@

# whisper_streaming
Whisper realtime streaming for long speech-to-text transcription and translation

<h1 align="center">WhisperLiveKit</h1>

**Turning Whisper into Real-Time Transcription System**

<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
</p>

Demonstration paper, by [Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek), [Raj Dabre](https://prajdabre.github.io/), [Ondřej Bojar](https://ufal.mff.cuni.cz/ondrej-bojar), 2023

<p align="center"><b>Real-time, Fully Local Speech-to-Text with Speaker Identification</b></p>

Abstract: Whisper is one of the recent state-of-the-art multilingual speech recognition and translation models; however, it is not designed for real-time transcription. In this paper, we build on top of Whisper and create Whisper-Streaming, an implementation of real-time speech transcription and translation of Whisper-like models. Whisper-Streaming uses a local agreement policy with self-adaptive latency to enable streaming transcription. We show that Whisper-Streaming achieves high quality and 3.3 seconds latency on an unsegmented long-form speech transcription test set, and we demonstrate its robustness and practical usability as a component in a live transcription service at a multilingual conference.

<p align="center">
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.15-dark_green"></a>
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-Apache 2.0-dark_green"></a>
</p>

[Paper PDF](https://aclanthology.org/2023.ijcnlp-demo.3.pdf), [Demo video](https://player.vimeo.com/video/840442741)

Real-time transcription directly to your browser, with a ready-to-use backend+server and a simple frontend.

[Slides](http://ufallab.ms.mff.cuni.cz/~machacek/pre-prints/AACL23-2.11.2023-Turning-Whisper-oral.pdf) -- 15 minutes oral presentation at IJCNLP-AACL 2023

#### Powered by Leading Research:

Please, cite us. [ACL Anthology](https://aclanthology.org/2023.ijcnlp-demo.3/), [Bibtex citation](https://aclanthology.org/2023.ijcnlp-demo.3.bib):

- Simul-[Whisper](https://github.com/backspacetg/simul_whisper)/[Streaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - Ultra-low latency transcription using the [AlignAtt policy](https://arxiv.org/pdf/2305.11408)
- [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting) (2025), based on the [distilled](https://huggingface.co/entai2965/nllb-200-distilled-600M-ctranslate2) [NLLB](https://arxiv.org/abs/2207.04672) (2022, 2024) - Simultaneous translation from & to 200 languages
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - Low latency transcription using the [LocalAgreement policy](https://www.isca-archive.org/interspeech_2020/liu20s_interspeech.pdf)
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - Advanced real-time speaker diarization
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - Real-time speaker diarization
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - Enterprise-grade Voice Activity Detection

> **Why not just run a simple Whisper model on every audio batch?** Whisper is designed for complete utterances, not real-time chunks. Processing small segments loses context, cuts off words mid-syllable, and produces poor transcription. WhisperLiveKit uses state-of-the-art simultaneous speech research for intelligent buffering and incremental processing.

### Architecture

<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />

*The backend supports multiple concurrent users. Voice Activity Detection reduces overhead when no voice is detected.*

### Installation & Quick Start

```bash
pip install whisperlivekit
```
@inproceedings{machacek-etal-2023-turning,
    title = "Turning Whisper into Real-Time Transcription System",
    author = "Mach{\'a}{\v{c}}ek, Dominik and
      Dabre, Raj and
      Bojar, Ond{\v{r}}ej",
    editor = "Saha, Sriparna and
      Sujaini, Herry",
    booktitle = "Proceedings of the 13th International Joint Conference on Natural Language Processing and the 3rd Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics: System Demonstrations",
    month = nov,
    year = "2023",
    address = "Bali, Indonesia",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.ijcnlp-demo.3",
    pages = "17--24",
}

> You can also clone the repo and `pip install -e .` for the latest version.

#### Quick Start

1. **Start the transcription server:**
   ```bash
   wlk --model base --language en
   ```

2. **Open your browser** and navigate to `http://localhost:8000`. Start speaking and watch your words appear in real-time!

> - See [tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) for the list of all available languages.
> - For HTTPS requirements, see the **Parameters** section for SSL configuration options.
> - The CLI entry point is exposed as both `wlk` and `whisperlivekit-server`; they are equivalent.

#### Use it to capture audio from web pages

Go to `chrome-extension` for instructions.

<p align="center">
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="600">
</p>

#### Optional Dependencies

| Optional | `pip install` |
|-----------|-------------|
| **Windows/Linux optimizations** | `faster-whisper` |
| **Apple Silicon optimizations** | `mlx-whisper` |
| **Translation** | `nllw` |
| **Speaker diarization** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
| OpenAI API | `openai` |
| *[Not recommended]* Speaker diarization with Diart | `diart` |

See **Parameters & Configuration** below on how to use them.

### Usage Examples

**Command-line Interface**: Start the transcription server with various options:

```bash
# Large model and translate from French to Danish
wlk --model large-v3 --language fr --target-language da

# Diarization and server listening on */80
wlk --host 0.0.0.0 --port 80 --model medium --diarization --language fr
```
## Installation

1) ``pip install librosa soundfile`` -- audio processing library

2) Whisper backend.

Several alternative backends are integrated. The most recommended one is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.

An alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`

Thirdly, it's also possible to run this software from the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/audio/createTranscription). This solution is fast and requires no GPU; just a small VM will suffice, but you will need to pay OpenAI for API access. Also note that, since each audio fragment is processed multiple times, the [price](https://openai.com/pricing) will be higher than obvious from the pricing page, so keep an eye on costs while using. Setting a higher chunk-size will reduce costs significantly.
Install with: `pip install openai`

For running with the openai-api backend, make sure that your [OpenAI API key](https://platform.openai.com/api-keys) is set in the `OPENAI_API_KEY` environment variable. For example, before running, do: `export OPENAI_API_KEY=sk-xxx` with *sk-xxx* replaced with your API key.

The backend is loaded only when chosen. The unused one does not have to be installed.

3) Optional, not recommended: sentence segmenter (aka sentence tokenizer)

Two buffer trimming options are integrated and evaluated. They have an impact on the quality and latency. The default "segment" option performs better according to our tests and does not require any sentence segmentation installed.

The other option, "sentence" -- trimming at the end of confirmed sentences -- requires a sentence segmenter installed. It splits punctuated text into sentences by full stops, avoiding the dots that are not full stops. The segmenters are language specific. The unused one does not have to be installed. We integrate the following segmenters, but suggestions for better alternatives are welcome.

- `pip install opus-fast-mosestokenizer` for the languages with codes `as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh`

- `pip install tokenize_uk` for Ukrainian -- `uk`

- for other languages, we integrate a well-performing multilingual model of `wtpsplit`. It requires `pip install torch wtpsplit`, and its neural model `wtp-canine-s-12l-no-adapters`. It is downloaded to the default huggingface cache during the first use.

- we did not find a segmenter for languages `as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt` that are supported by Whisper and not by wtpsplit. The default fallback option for them is wtpsplit with unspecified language. Alternative suggestions welcome.

In case of installation issues with opus-fast-mosestokenizer, especially on Windows and Mac, we recommend using only the "segment" option, which does not require it.

## Usage

### Real-time simulation from audio file
```
usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}] [--model_cache_dir MODEL_CACHE_DIR] [--model_dir MODEL_DIR] [--lan LAN] [--task {transcribe,translate}]
                         [--backend {faster-whisper,whisper_timestamped,openai-api}] [--vad] [--buffer_trimming {sentence,segment}] [--buffer_trimming_sec BUFFER_TRIMMING_SEC] [--start_at START_AT] [--offline] [--comp_unaware]
                         audio_path

positional arguments:
  audio_path            Filename of 16kHz mono channel wav, on which live streaming is simulated.

options:
  -h, --help            show this help message and exit
  --min-chunk-size MIN_CHUNK_SIZE
                        Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.
  --model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}
                        Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.
  --model_cache_dir MODEL_CACHE_DIR
                        Overriding the default model cache dir where models downloaded from the hub are saved
  --model_dir MODEL_DIR
                        Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
  --lan LAN, --language LAN
                        Source language code, e.g. en,de,cs, or 'auto' for language detection.
  --task {transcribe,translate}
                        Transcribe or translate.
  --backend {faster-whisper,whisper_timestamped,openai-api}
                        Load only this backend for Whisper processing.
  --vad                 Use VAD = voice activity detection, with the default parameters.
  --buffer_trimming {sentence,segment}
                        Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.
  --buffer_trimming_sec BUFFER_TRIMMING_SEC
                        Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.
  --start_at START_AT   Start processing audio at this time.
  --offline             Offline mode.
  --comp_unaware        Computationally unaware simulation.
```

Example:

It simulates realtime processing from a pre-recorded mono 16k wav file.

```
python3 whisper_online.py en-demo16.wav --language en --min-chunk-size 1 > out.txt
```

Simulation modes:

- default mode, no special option: real-time simulation from file, computationally aware. The chunk size is `MIN_CHUNK_SIZE` or larger, if more audio arrived during the last update computation.

- `--comp_unaware` option: computationally unaware simulation. It means that the timer that counts the emission times "stops" when the model is computing. The chunk size is always `MIN_CHUNK_SIZE`. The latency is caused only by the model being unable to confirm the output, e.g. because of language ambiguity etc., and not because of slow hardware or suboptimal implementation. We implement this feature for finding the lower bound for latency.

- `--start_at START_AT`: Start processing audio at this time. The first update receives the whole audio by `START_AT`. It is useful for debugging, e.g. when we observe a bug at a specific time in the audio file and want to reproduce it quickly, without long waiting.

- `--offline` option: It processes the whole audio file at once, in offline mode. We implement it to find out the lowest possible WER on a given audio file.

### Output format

```
2691.4399 300 1380 Chairman, thank you.
6914.5501 1940 4940 If the debate today had a
9019.0277 5160 7160 the subject the situation in
10065.1274 7180 7480 Gaza
11058.3558 7480 9460 Strip, I might
12224.3731 9460 9760 have
13555.1929 9760 11060 joined Mrs.
14928.5479 11140 12240 De Kaiser and all the
16588.0787 12240 12560 other
18324.9285 12560 14420 colleagues across the
```

[See description here](https://github.com/ufal/whisper_streaming/blob/d915d790a62d7be4e7392dde1480e7981eb142ae/whisper_online.py#L361)
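Each line splits into numeric fields followed by the text. A small, hedged parsing sketch, assuming the layout documented at the link above (emission time in ms since processing started, then the segment's begin and end timestamps in ms, then the text):

```python
# Hedged sketch: parse one line of the output format shown above.
def parse_line(line: str):
    emitted, beg, end, text = line.split(" ", 3)
    return float(emitted), int(beg), int(end), text

print(parse_line("2691.4399 300 1380 Chairman, thank you."))
# (2691.4399, 300, 1380, 'Chairman, thank you.')
```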
### As a module

TL;DR: use the OnlineASRProcessor object and its methods insert_audio_chunk and process_iter.

The code whisper_online.py is nicely commented; read it as the full documentation.

This pseudocode describes the interface that we suggest for your implementation. You can implement any features that you need for your application.

**Python API Integration**: Check [basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) for a more complete example of how to use the functions and classes.

```python
from whisper_online import *
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from contextlib import asynccontextmanager
import asyncio

src_lan = "en"  # source language
tgt_lan = "en"  # target language -- same as source for ASR, "en" if translate task is used
transcription_engine = None

asr = FasterWhisperASR(lan, "large-v2")  # loads and wraps Whisper model
# set options:
# asr.set_translate_task()  # it will translate from lan into English
# asr.use_vad()  # set using VAD

@asynccontextmanager
async def lifespan(app: FastAPI):
    global transcription_engine
    transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
    yield

online = OnlineASRProcessor(asr)  # create processing object with default buffer trimming option
app = FastAPI(lifespan=lifespan)

while audio_has_not_ended:  # processing loop:
    a = # receive new audio chunk (and e.g. wait for min_chunk_size seconds first, ...)
    online.insert_audio_chunk(a)
    o = online.process_iter()
    print(o)  # do something with current partial output
# at the end of this audio processing
o = online.finish()
print(o)  # do something with the last output

async def handle_websocket_results(websocket: WebSocket, results_generator):
    async for response in results_generator:
        await websocket.send_json(response)
    await websocket.send_json({"type": "ready_to_stop"})

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    global transcription_engine

    online.init()  # refresh if you're going to re-use the object for the next audio
    # Create a new AudioProcessor for each connection, passing the shared engine
    audio_processor = AudioProcessor(transcription_engine=transcription_engine)
    results_generator = await audio_processor.create_tasks()
    results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
    await websocket.accept()
    while True:
        message = await websocket.receive_bytes()
        await audio_processor.process_audio(message)
```
### Server -- real-time from mic
|
||||
**Frontend Implementation**: The package includes an HTML/JavaScript implementation [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html). You can also import it using `from whisperlivekit import get_inline_ui_html` & `page = get_inline_ui_html()`
|
||||
|
||||
`whisper_online_server.py` has the same model options as `whisper_online.py`, plus `--host` and `--port` of the TCP connection and the `--warmup-file`. See the help message (`-h` option).
|
||||
|
||||
Client example:
|
||||
## Parameters & Configuration
|
||||
|
||||
```
|
||||
arecord -f S16_LE -c1 -r 16000 -t raw -D default | nc localhost 43001
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisper model size. List and recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/available_models.md) | `small` |
|
||||
| `--model-path` | Local .pt file/directory **or** Hugging Face repo ID containing the Whisper model. Overrides `--model`. Recommandations [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/models_compatible_formats.md) | `None` |
|
||||
| `--language` | List [here](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py). If you use `auto`, the model attempts to detect the language automatically, but it tends to bias towards English. | `auto` |
|
||||
| `--target-language` | If sets, translates using [NLLW](https://github.com/QuentinFuxa/NoLanguageLeftWaiting). [200 languages available](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/docs/supported_languages.md). If you want to translate to english, you can also use `--direct-english-translation`. The STT model will try to directly output the translation. | `None` |
|
||||
| `--diarization` | Enable speaker identification | `False` |
|
||||
| `--backend-policy` | Streaming strategy: `1`/`simulstreaming` uses AlignAtt SimulStreaming, `2`/`localagreement` uses the LocalAgreement policy | `simulstreaming` |
|
||||
| `--backend` | Whisper implementation selector. `auto` picks MLX on macOS (if installed), otherwise Faster-Whisper, otherwise vanilla Whisper. You can also force `mlx-whisper`, `faster-whisper`, `whisper`, or `openai-api` (LocalAgreement only) | `auto` |
|
||||
| `--no-vac` | Disable Voice Activity Controller | `False` |
|
||||
| `--no-vad` | Disable Voice Activity Detection | `False` |
|
||||
| `--warmup-file` | Audio file path for model warmup | `jfk.wav` |
|
||||
| `--host` | Server host address | `localhost` |
|
||||
| `--port` | Server port | `8000` |
|
||||
| `--ssl-certfile` | Path to the SSL certificate file (for HTTPS support) | `None` |
|
||||
| `--ssl-keyfile` | Path to the SSL private key file (for HTTPS support) | `None` |
|
||||
| `--forwarded-allow-ips` | IP or IPs allowed to reverse-proxy the whisperlivekit-server. Supported types are IP addresses (e.g. 127.0.0.1), IP networks (e.g. 10.100.0.0/16), or literals (e.g. /path/to/socket.sock) | `None` |
|
||||
| `--pcm-input` | Expect raw PCM (s16le) data as input and bypass FFmpeg. The frontend will use AudioWorklet instead of MediaRecorder | `False` |
|
||||
|
||||
| Translation options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nllb-backend` | `transformers` or `ctranslate2` | `ctranslate2` |
|
||||
| `--nllb-size` | `600M` or `1.3B` | `600M` |
|
||||
|
||||
| Diarization options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization-backend` | `diart` or `sortformer` | `sortformer` |
|
||||
| `--disable-punctuation-split` | Disable punctuation-based splits. See #214 | `False` |
|
||||
| `--segmentation-model` | Hugging Face model ID for Diart segmentation model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Hugging Face model ID for Diart embedding model. [Available models](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
| SimulStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--disable-fast-encoder` | Disable the Faster-Whisper or MLX Whisper encoder backends (if installed). Inference can be slower, but this helps when GPU memory is limited | `False` |
|
||||
| `--custom-alignment-heads` | Use your own alignment heads, useful when `--model-dir` is used. Use `scripts/determine_alignment_heads.py` to extract them. <img src="scripts/alignment_heads.png" alt="Alignment heads" width="300"> | `None` |
|
||||
| `--frame-threshold` | AlignAtt frame threshold (lower = faster, higher = more accurate) | `25` |
|
||||
| `--beams` | Number of beams for beam search (1 = greedy decoding) | `1` |
|
||||
| `--decoder` | Force decoder type (`beam` or `greedy`) | `auto` |
|
||||
| `--audio-max-len` | Maximum audio buffer length (seconds) | `30.0` |
|
||||
| `--audio-min-len` | Minimum audio length to process (seconds) | `0.0` |
|
||||
| `--cif-ckpt-path` | Path to CIF model for word boundary detection | `None` |
|
||||
| `--never-fire` | Never truncate incomplete words | `False` |
|
||||
| `--init-prompt` | Initial prompt for the model | `None` |
|
||||
| `--static-init-prompt` | Static prompt that doesn't scroll | `None` |
|
||||
| `--max-context-tokens` | Maximum context tokens | `None` |
|
||||
| `--preload-model-count` | Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent users) | `1` |
|
||||
|
||||
|
||||
|
||||
| WhisperStreaming backend options | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | Use confidence scores for faster validation | `False` |
|
||||
| `--buffer_trimming` | Buffer trimming strategy (`sentence` or `segment`) | `segment` |
|
||||
|
||||
|
||||
|
||||
|
||||
> For diarization using Diart, you need to accept the user conditions [here](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model, [here](https://huggingface.co/pyannote/segmentation-3.0) for the `pyannote/segmentation-3.0` model, and [here](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model. **Then**, log in to Hugging Face: `huggingface-cli login`
|
||||
|
||||
### 🚀 Deployment Guide
|
||||
|
||||
To deploy WhisperLiveKit in production:
|
||||
|
||||
1. **Server Setup**: Install a production ASGI server and launch it with multiple workers
|
||||
```bash
|
||||
pip install uvicorn gunicorn
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
|
||||
2. **Frontend**: Host your customized version of the `html` example and make sure the WebSocket connection points to your server
|
||||
|
||||
3. **Nginx Configuration** (recommended for production):
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
    }
}
|
||||
```
|
||||
|
||||
4. **HTTPS Support**: For secure deployments, use "wss://" instead of "ws://" in WebSocket URL
|
||||
|
||||
## 🐋 Docker
|
||||
|
||||
Deploy the application easily using Docker with GPU or CPU support.
|
||||
|
||||
### Prerequisites
|
||||
- Docker installed on your system
|
||||
- For GPU support: NVIDIA Docker runtime installed
|
||||
|
||||
### Quick Start
|
||||
|
||||
**With GPU acceleration (recommended):**
|
||||
```bash
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
|
||||
**CPU only:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
|
||||
### Advanced Usage
|
||||
|
||||
**Custom configuration:**
|
||||
```bash
|
||||
# Example with custom model and language
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
### Memory Requirements
|
||||
- **Large models**: Ensure your Docker runtime has sufficient memory allocated
|
||||
|
||||
|
||||
## Background
|
||||
|
||||
Default Whisper is intended for audio chunks of at most 30 seconds that contain
one full sentence. Longer audio files must be split into shorter chunks and
merged with an "init prompt". In low-latency simultaneous streaming mode, simple
and naive chunking into fixed-sized windows does not work well: it can split a
word in the middle. It is also necessary to know when the transcript is stable,
should be confirmed ("committed") and followed up, and when future content
makes the transcript clearer.
|
||||
|
||||
For that, there is LocalAgreement-n policy: if n consecutive updates, each with
|
||||
a newly available audio stream chunk, agree on a prefix transcript, it is
|
||||
confirmed. (Reference: CUNI-KIT at IWSLT 2022 etc.)
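As a minimal illustration of the LocalAgreement-2 idea (a sketch, not this project's actual implementation), the committed prefix is simply the longest common word prefix of two consecutive hypotheses:

```python
def agreed_prefix(prev_hypothesis: list[str], new_hypothesis: list[str]) -> list[str]:
    """Return the longest common prefix of two word-level hypotheses."""
    prefix = []
    for prev_word, new_word in zip(prev_hypothesis, new_hypothesis):
        if prev_word != new_word:
            break
        prefix.append(prev_word)
    return prefix

# Two consecutive updates over a growing audio buffer:
h_t1 = "the quick brown fox jumps".split()
h_t2 = "the quick brown fox jumped over".split()
print(agreed_prefix(h_t1, h_t2))  # ['the', 'quick', 'brown', 'fox'] -> committed
```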
|
||||
|
||||
In this project, we re-use the idea of Peter Polák from this demo:
|
||||
https://github.com/pe-trik/transformers/blob/online_decode/examples/pytorch/online-decoding/whisper-online-demo.py
|
||||
However, it doesn't do any sentence segmentation, but Whisper produces
punctuation and the libraries `faster-whisper` and `whisper-timestamped` provide
word-level timestamps. In short: we consecutively process new audio chunks, emit
the transcripts that are confirmed by 2 iterations, and scroll the audio
processing buffer on the timestamp of a confirmed complete sentence. The
processing audio buffer is not too long and the processing is fast.
|
||||
|
||||
In more detail: we use the init prompt, we handle the inaccurate timestamps, we
|
||||
re-process confirmed sentence prefixes and skip them, making sure they don't
|
||||
overlap, and we limit the processing buffer window.
|
||||
|
||||
### Performance evaluation
|
||||
|
||||
[See the paper.](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf)
|
||||
|
||||
### Contributions
|
||||
|
||||
Contributions are welcome. We acknowledge especially:
|
||||
|
||||
- [The GitHub contributors](https://github.com/ufal/whisper_streaming/graphs/contributors) for their pull requests with new features and bugfixes.
|
||||
- [The translation of this repo into Chinese.](https://github.com/Gloridust/whisper_streaming_CN)
|
||||
- [Ondřej Plátek](https://opla.cz/) for the paper pre-review.
|
||||
- [Peter Polák](https://ufal.mff.cuni.cz/peter-polak) for the original idea.
|
||||
- The UEDIN team of the [ELITR project](https://elitr.eu) for the original line_packet.py.
|
||||
|
||||
|
||||
## Contact
|
||||
|
||||
Dominik Macháček, machacek@ufal.mff.cuni.cz
|
||||
|
||||
#### Customization
|
||||
|
||||
- `--build-arg` Options:
|
||||
- `EXTRAS="whisper-timestamped"` - Add extras to the image's installation (no spaces). Remember to set necessary container options!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - Pre-load a model cache for faster first-time start
|
||||
- `HF_TKN_FILE="./token"` - Add your Hugging Face Hub access token to download gated models
|
||||
|
||||
## 🔮 Use Cases
|
||||
Capture discussions in real-time for meeting transcription, help hearing-impaired users follow conversations through accessibility tools, transcribe podcasts or videos automatically for content creation, transcribe support calls with speaker identification for customer service...
|
||||
|
||||
258
ReadmeJP.md
Normal file
@@ -0,0 +1,258 @@
|
||||
<h1 align="center">WhisperLiveKit</h1>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/demo.png" alt="WhisperLiveKit Demo" width="730">
|
||||
</p>
|
||||
|
||||
<p align="center"><b>話者識別機能付き、リアルタイム、完全ローカルな音声テキスト変換</b></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="PyPI Version" src="https://img.shields.io/pypi/v/whisperlivekit?color=g"></a>
|
||||
<a href="https://pepy.tech/project/whisperlivekit"><img alt="PyPI Downloads" src="https://static.pepy.tech/personalized-badge/whisperlivekit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=installations"></a>
|
||||
<a href="https://pypi.org/project/whisperlivekit/"><img alt="Python Versions" src="https://img.shields.io/badge/python-3.9--3.13-dark_green"></a>
|
||||
<a href="https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/badge/License-MIT/Dual Licensed-dark_green"></a>
|
||||
</p>
|
||||
|
||||
すぐに使えるバックエンド+サーバーとシンプルなフロントエンドで、リアルタイムの音声文字起こしをブラウザに直接提供します。✨
|
||||
|
||||
#### 主要な研究による技術:
|
||||
|
||||
- [SimulStreaming](https://github.com/ufal/SimulStreaming) (SOTA 2025) - AlignAttポリシーによる超低遅延文字起こし
|
||||
- [WhisperStreaming](https://github.com/ufal/whisper_streaming) (SOTA 2023) - LocalAgreementポリシーによる低遅延文字起こし
|
||||
- [Streaming Sortformer](https://arxiv.org/abs/2507.18446) (SOTA 2025) - 高度なリアルタイム話者ダイアライゼーション
|
||||
- [Diart](https://github.com/juanmc2005/diart) (SOTA 2021) - リアルタイム話者ダイアライゼーション
|
||||
- [Silero VAD](https://github.com/snakers4/silero-vad) (2024) - エンタープライズグレードの音声区間検出
|
||||
|
||||
> **なぜ各音声バッチで単純なWhisperモデルを実行しないのか?** Whisperは完全な発話向けに設計されており、リアルタイムのチャンク向けではありません。小さなセグメントを処理するとコンテキストが失われ、単語が音節の途中で途切れ、質の悪い文字起こしになります。WhisperLiveKitは、インテリジェントなバッファリングとインクリメンタルな処理のために、最先端の同時音声研究を利用しています。
|
||||
|
||||
### アーキテクチャ
|
||||
|
||||
<img alt="Architecture" src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/architecture.png" />
|
||||
|
||||
*バックエンドは複数の同時ユーザーをサポートします。音声が検出されない場合、音声区間検出がオーバーヘッドを削減します。*
|
||||
|
||||
### インストールとクイックスタート
|
||||
|
||||
```bash
|
||||
pip install whisperlivekit
|
||||
```
|
||||
|
||||
> **FFmpegが必要です** WhisperLiveKitを使用する前にインストールする必要があります。
|
||||
>
|
||||
> | OS | インストール方法 |
|
||||
> |-----------|-------------|
|
||||
> | Ubuntu/Debian | `sudo apt install ffmpeg` |
|
||||
> | MacOS | `brew install ffmpeg` |
|
||||
> | Windows | https://ffmpeg.org/download.html から.exeをダウンロードし、PATHに追加 |
|
||||
|
||||
#### クイックスタート
|
||||
1. **文字起こしサーバーを起動します:**
|
||||
```bash
|
||||
whisperlivekit-server --model base --language en
|
||||
```
|
||||
|
||||
2. **ブラウザを開き** `http://localhost:8000` にアクセスします。話し始めると、あなたの言葉がリアルタイムで表示されます!
|
||||
|
||||
|
||||
> - 利用可能なすべての言語のリストについては、[tokenizer.py](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py) を参照してください。
|
||||
> - HTTPSの要件については、**パラメータ**セクションのSSL設定オプションを参照してください。
|
||||
|
||||
#### オプションの依存関係
|
||||
|
||||
| オプション | `pip install` |
|
||||
|-----------|-------------|
|
||||
| **Sortformerによる話者ダイアライゼーション** | `git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]` |
|
||||
| Diartによる話者ダイアライゼーション | `diart` |
|
||||
| オリジナルのWhisperバックエンド | `whisper` |
|
||||
| タイムスタンプ改善バックエンド | `whisper-timestamped` |
|
||||
| Apple Silicon最適化バックエンド | `mlx-whisper` |
|
||||
| OpenAI APIバックエンド | `openai` |
|
||||
|
||||
それらの使用方法については、以下の**パラメータと設定**を参照してください。
|
||||
|
||||
### 使用例
|
||||
|
||||
**コマンドラインインターフェース**: 様々なオプションで文字起こしサーバーを起動します:
|
||||
|
||||
```bash
|
||||
# デフォルト(small)より良いモデルを使用
|
||||
whisperlivekit-server --model large-v3
|
||||
|
||||
# ダイアライゼーションと言語を指定した高度な設定
|
||||
whisperlivekit-server --host 0.0.0.0 --port 8000 --model medium --diarization --language fr
|
||||
```
|
||||
|
||||
**Python API連携**: 関数やクラスの使用方法のより完全な例については、[basic_server](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/basic_server.py) を確認してください。
|
||||
|
||||
```python
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, parse_args
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncio
|
||||
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(model="medium", diarization=True, lan="en")
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
async def handle_websocket_results(websocket: WebSocket, results_generator):
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response)
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
|
||||
# 接続ごとに新しいAudioProcessorを作成し、共有エンジンを渡す
|
||||
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
await websocket.accept()
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
```
|
||||
|
||||
**フロントエンド実装**: パッケージにはHTML/JavaScript実装が[ここ](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/web/live_transcription.html)に含まれています。`from whisperlivekit import get_web_interface_html` & `page = get_web_interface_html()` を使ってインポートすることもできます。
|
||||
|
||||
|
||||
## パラメータと設定
|
||||
|
||||
重要なパラメータのリストを変更できます。しかし、何を*変更すべき*でしょうか?
|
||||
- `--model` サイズ。リストと推奨事項は[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/available_models.md)
|
||||
- `--language`。リストは[こちら](https://github.com/QuentinFuxa/WhisperLiveKit/blob/main/whisperlivekit/simul_whisper/whisper/tokenizer.py)。`auto`を使用すると、モデルは自動的に言語を検出しようとしますが、英語に偏る傾向があります。
|
||||
- `--backend`? `simulstreaming`が正しく動作しない場合や、デュアルライセンス要件を避けたい場合は`--backend faster-whisper`に切り替えることができます。
|
||||
- `--warmup-file`、もしあれば
|
||||
- `--host`, `--port`, `--ssl-certfile`, `--ssl-keyfile`、サーバーをセットアップする場合
|
||||
- `--diarization`、使用したい場合。
|
||||
|
||||
残りは推奨しません。しかし、以下があなたのオプションです。
|
||||
|
||||
| パラメータ | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--model` | Whisperモデルのサイズ。 | `small` |
|
||||
| `--language` | ソース言語コードまたは`auto` | `auto` |
|
||||
| `--task` | `transcribe`または`translate` | `transcribe` |
|
||||
| `--backend` | 処理バックエンド | `simulstreaming` |
|
||||
| `--min-chunk-size` | 最小音声チャンクサイズ(秒) | `1.0` |
|
||||
| `--no-vac` | 音声アクティビティコントローラーを無効化 | `False` |
|
||||
| `--no-vad` | 音声区間検出を無効化 | `False` |
|
||||
| `--warmup-file` | モデルのウォームアップ用音声ファイルパス | `jfk.wav` |
|
||||
| `--host` | サーバーホストアドレス | `localhost` |
|
||||
| `--port` | サーバーポート | `8000` |
|
||||
| `--ssl-certfile` | SSL証明書ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
| `--ssl-keyfile` | SSL秘密鍵ファイルへのパス(HTTPSサポート用) | `None` |
|
||||
|
||||
|
||||
| WhisperStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--confidence-validation` | 高速な検証のために信頼スコアを使用 | `False` |
|
||||
| `--buffer_trimming` | バッファトリミング戦略(`sentence`または`segment`) | `segment` |
|
||||
|
||||
|
||||
| SimulStreamingバックエンドオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--frame-threshold` | AlignAttフレームしきい値(低いほど速く、高いほど正確) | `25` |
|
||||
| `--beams` | ビームサーチのビーム数(1 = 貪欲デコーディング) | `1` |
|
||||
| `--decoder` | デコーダタイプを強制(`beam`または`greedy`) | `auto` |
|
||||
| `--audio-max-len` | 最大音声バッファ長(秒) | `30.0` |
|
||||
| `--audio-min-len` | 処理する最小音声長(秒) | `0.0` |
|
||||
| `--cif-ckpt-path` | 単語境界検出用CIFモデルへのパス | `None` |
|
||||
| `--never-fire` | 未完了の単語を決して切り捨てない | `False` |
|
||||
| `--init-prompt` | モデルの初期プロンプト | `None` |
|
||||
| `--static-init-prompt` | スクロールしない静的プロンプト | `None` |
|
||||
| `--max-context-tokens` | 最大コンテキストトークン数 | `None` |
|
||||
| `--model-path` | .ptモデルファイルへの直接パス。見つからない場合はダウンロード | `./base.pt` |
|
||||
| `--preloaded-model-count` | オプション。メモリにプリロードするモデルの数(予想される同時ユーザー数まで設定) | `1` |
|
||||
|
||||
| ダイアライゼーションオプション | 説明 | デフォルト |
|
||||
|-----------|-------------|---------|
|
||||
| `--diarization` | 話者識別を有効化 | `False` |
|
||||
| `--diarization-backend` | `diart`または`sortformer` | `sortformer` |
|
||||
| `--segmentation-model` | DiartセグメンテーションモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `pyannote/segmentation-3.0` |
|
||||
| `--embedding-model` | Diart埋め込みモデルのHugging FaceモデルID。[利用可能なモデル](https://github.com/juanmc2005/diart/tree/main?tab=readme-ov-file#pre-trained-models) | `speechbrain/spkrec-ecapa-voxceleb` |
|
||||
|
||||
|
||||
> Diartを使用したダイアライゼーションには、pyannote.audioモデルへのアクセスが必要です:
|
||||
> 1. `pyannote/segmentation`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation)
|
||||
> 2. `pyannote/segmentation-3.0`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/segmentation-3.0)
|
||||
> 3. `pyannote/embedding`モデルの[ユーザー条件に同意](https://huggingface.co/pyannote/embedding)
|
||||
>4. HuggingFaceでログイン: `huggingface-cli login`
|
||||
|
||||
### 🚀 デプロイガイド
|
||||
|
||||
WhisperLiveKitを本番環境にデプロイするには:
|
||||
|
||||
1. **サーバーセットアップ**: 本番用ASGIサーバーをインストールし、複数のワーカーで起動します
|
||||
```bash
|
||||
pip install uvicorn gunicorn
|
||||
gunicorn -k uvicorn.workers.UvicornWorker -w 4 your_app:app
|
||||
```
|
||||
|
||||
2. **フロントエンド**: カスタマイズした`html`のバージョンをホストし、WebSocket接続が正しくポイントするようにします
|
||||
|
||||
3. **Nginx設定** (本番環境で推奨):
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com;
|
||||
location / {
|
||||
proxy_pass http://localhost:8000;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
    }
}
|
||||
```
|
||||
|
||||
4. **HTTPSサポート**: 安全なデプロイメントのために、WebSocket URLで "ws://" の代わりに "wss://" を使用します
|
||||
|
||||
## 🐋 Docker
|
||||
|
||||
GPUまたはCPUサポート付きでDockerを使用してアプリケーションを簡単にデプロイします。
|
||||
|
||||
### 前提条件
|
||||
- Dockerがシステムにインストールされていること
|
||||
- GPUサポートの場合: NVIDIA Dockerランタイムがインストールされていること
|
||||
|
||||
### クイックスタート
|
||||
|
||||
**GPUアクセラレーション付き (推奨):**
|
||||
```bash
|
||||
docker build -t wlk .
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
**CPUのみ:**
|
||||
```bash
|
||||
docker build -f Dockerfile.cpu -t wlk .
|
||||
docker run -p 8000:8000 --name wlk wlk
|
||||
```
|
||||
|
||||
### 高度な使用法
|
||||
|
||||
**カスタム設定:**
|
||||
```bash
|
||||
# カスタムモデルと言語の例
|
||||
docker run --gpus all -p 8000:8000 --name wlk wlk --model large-v3 --language fr
|
||||
```
|
||||
|
||||
### メモリ要件
|
||||
- **大規模モデル**: Dockerランタイムに十分なメモリが割り当てられていることを確認してください
|
||||
|
||||
|
||||
#### カスタマイズ
|
||||
|
||||
- `--build-arg` オプション:
|
||||
- `EXTRAS="whisper-timestamped"` - イメージのインストールにエクストラを追加します(スペースなし)。必要なコンテナオプションを設定することを忘れないでください!
|
||||
- `HF_PRECACHE_DIR="./.cache/"` - 初回起動を高速化するためにモデルキャッシュをプリロードします
|
||||
- `HF_TKN_FILE="./token"` - ゲート付きモデルをダウンロードするためにHugging Face Hubアクセストークンを追加します
|
||||
|
||||
## 🔮 ユースケース
|
||||
会議の文字起こしのためにリアルタイムで議論をキャプチャする、聴覚障害のあるユーザーがアクセシビリティツールを通じて会話を追うのを助ける、コンテンツ作成のためにポッドキャストやビデオを自動的に文字起こしする、カスタマーサービスのために話者識別付きでサポートコールを文字起こしする...
|
||||
BIN
architecture.png
Normal file
|
After Width: | Height: | Size: 406 KiB |
19
chrome-extension/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
## WhisperLiveKit Chrome Extension v0.1.1
|
||||
Capture the audio of your current tab, then transcribe, diarize and translate it using WhisperLiveKit, in Chrome and other Chromium-based browsers.
|
||||
|
||||
> Currently, only the tab audio is captured; your microphone audio is not recorded.
|
||||
|
||||
<img src="https://raw.githubusercontent.com/QuentinFuxa/WhisperLiveKit/refs/heads/main/chrome-extension/demo-extension.png" alt="WhisperLiveKit Demo" width="730">
|
||||
|
||||
## Running this extension
|
||||
1. Run `python sync_extension.py` to copy frontend files to the `chrome-extension` directory.
|
||||
2. Load the `chrome-extension` directory in Chrome as an unpacked extension.
|
||||
|
||||
|
||||
## Devs:
|
||||
- Unfortunately, it is impossible to capture tab audio when the extension runs as a side panel:
|
||||
- https://issues.chromium.org/issues/40926394
|
||||
- https://groups.google.com/a/chromium.org/g/chromium-extensions/c/DET2SXCFnDg
|
||||
- https://issues.chromium.org/issues/40916430
|
||||
|
||||
- To capture the microphone in an extension, there are workarounds: https://github.com/justinmann/sidepanel-audio-issue , https://medium.com/@lynchee.owo/how-to-enable-microphone-access-in-chrome-extensions-by-code-924295170080 (see the comments)
|
||||
9
chrome-extension/background.js
Normal file
@@ -0,0 +1,9 @@
|
||||
chrome.runtime.onInstalled.addListener((details) => {
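// Bail out unless the reason contains "install" (i.e. a fresh install); then open the welcome page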
|
||||
if (details.reason.search(/install/g) === -1) {
|
||||
return
|
||||
}
|
||||
chrome.tabs.create({
|
||||
url: chrome.runtime.getURL("welcome.html"),
|
||||
active: true
|
||||
})
|
||||
})
|
||||
BIN
chrome-extension/demo-extension.png
Normal file
|
After Width: | Height: | Size: 5.8 MiB |
BIN
chrome-extension/icons/icon128.png
Normal file
|
After Width: | Height: | Size: 5.8 KiB |
BIN
chrome-extension/icons/icon16.png
Normal file
|
After Width: | Height: | Size: 376 B |
BIN
chrome-extension/icons/icon32.png
Normal file
|
After Width: | Height: | Size: 823 B |
BIN
chrome-extension/icons/icon48.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
23
chrome-extension/manifest.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"manifest_version": 3,
|
||||
"name": "WhisperLiveKit Tab Capture",
|
||||
"version": "1.0",
|
||||
"description": "Capture and transcribe audio from browser tabs using WhisperLiveKit.",
|
||||
"icons": {
|
||||
"16": "icons/icon16.png",
|
||||
"32": "icons/icon32.png",
|
||||
"48": "icons/icon48.png",
|
||||
"128": "icons/icon128.png"
|
||||
},
|
||||
"action": {
|
||||
"default_title": "WhisperLiveKit Tab Capture",
|
||||
"default_popup": "live_transcription.html"
|
||||
},
|
||||
"permissions": [
|
||||
"scripting",
|
||||
"tabCapture",
|
||||
"offscreen",
|
||||
"activeTab",
|
||||
"storage"
|
||||
]
|
||||
}
|
||||
12
chrome-extension/requestPermissions.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Request Permissions</title>
|
||||
<script src="requestPermissions.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
This page exists to work around an issue with Chrome that blocks permission
requests from chrome extensions.
|
||||
<button id="requestMicrophone">Request Microphone</button>
|
||||
</body>
|
||||
</html>
|
||||
17
chrome-extension/requestPermissions.js
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Requests user permission for microphone access.
|
||||
* @returns {Promise<void>} A Promise that resolves when permission is granted or rejects with an error.
|
||||
*/
|
||||
async function getUserPermission() {
|
||||
console.log("Getting user permission for microphone access...");
|
||||
await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state == "granted") {
|
||||
window.close();
|
||||
}
|
||||
}
|
||||
|
||||
// Call the function to request microphone permission
|
||||
getUserPermission();
|
||||
29
chrome-extension/sidepanel.js
Normal file
@@ -0,0 +1,29 @@
|
||||
console.log("sidepanel.js");
|
||||
|
||||
async function run() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
|
||||
if (micPermission.state !== "granted") {
|
||||
chrome.tabs.create({ url: "requestPermissions.html" });
|
||||
}
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
document.getElementById(
|
||||
"audioPermission"
|
||||
).innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void run();
|
||||
264
docs/API.md
Normal file
@@ -0,0 +1,264 @@
|
||||
# WhisperLiveKit WebSocket API Documentation
|
||||
|
||||
> **Note**: The new API structure described in this document is currently being deployed.

This documentation is intended for developers who want to build custom frontends.
|
||||
|
||||
WLK provides real-time speech transcription, speaker diarization, and translation through a WebSocket API. The server sends incremental updates as audio is processed, allowing clients to display live transcription results with minimal latency.
|
||||
|
||||
---
|
||||
|
||||
## Legacy API (Current)
|
||||
|
||||
### Message Structure
|
||||
|
||||
The current API sends complete state snapshots on each update (several times per second).
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": str,
|
||||
"status": str,
|
||||
"lines": [
|
||||
{
|
||||
"speaker": int,
|
||||
"text": str,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"translation": str | null,
|
||||
"detected_language": str
|
||||
}
|
||||
],
|
||||
"buffer_transcription": str,
|
||||
"buffer_diarization": str,
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
```
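As a rough sketch of how a custom client might consume these snapshots (assuming a server running locally with default settings, the `/asr` endpoint from the README, and audio chunks already encoded in a format the server accepts):

```python
import asyncio
import json

import websockets  # already a dependency of whisperlivekit


async def stream(audio_chunks):
    """Send audio chunks to a local server and print each state snapshot."""
    async with websockets.connect("ws://localhost:8000/asr") as ws:
        async def receive():
            async for message in ws:
                state = json.loads(message)
                if state.get("type") == "ready_to_stop":
                    break
                # Each message is a full snapshot: re-render every line plus the buffers
                for line in state.get("lines", []):
                    print(f'[speaker {line["speaker"]}] {line["text"]}')
                print("(buffer)", state.get("buffer_transcription", ""))

        receiver = asyncio.create_task(receive())
        for chunk in audio_chunks:  # bytes objects, e.g. recorded WebM/Opus or raw PCM
            await ws.send(chunk)
        await receiver
```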
|
||||
|
||||
---
|
||||
|
||||
## New API (Under Development)
|
||||
|
||||
### Philosophy
|
||||
|
||||
Principles:
|
||||
|
||||
- **Incremental Updates**: Only updates and new segments are sent
|
||||
- **Ephemeral Buffers**: Temporary, unvalidated data displayed in real-time but overwritten on next update, at speaker level
|
||||
|
||||
|
||||
## Message Format
|
||||
|
||||
|
||||
```typescript
|
||||
{
|
||||
"type": "transcript_update",
|
||||
"status": "active_transcription" | "no_audio_detected",
|
||||
"segments": [
|
||||
{
|
||||
"id": number,
|
||||
"speaker": number,
|
||||
"text": string,
|
||||
"start_speaker": float,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"language": string | null,
|
||||
"translation": string,
|
||||
"words": [
|
||||
{
|
||||
"text": string,
|
||||
"start": float,
|
||||
"end": float,
|
||||
"validated": {
|
||||
"text": boolean,
|
||||
"speaker": boolean,
|
||||
}
|
||||
}
|
||||
],
|
||||
"buffer": {
|
||||
"transcription": string,
|
||||
"diarization": string,
|
||||
"translation": string
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"remaining_time_transcription": float,
|
||||
"remaining_time_diarization": float
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Other Message Types
|
||||
|
||||
#### Config Message (sent on connection)
|
||||
```json
|
||||
{
|
||||
"type": "config",
|
||||
"useAudioWorklet": true / false
|
||||
}
|
||||
```
|
||||
|
||||
#### Ready to Stop Message (sent after processing complete)
|
||||
```json
|
||||
{
|
||||
"type": "ready_to_stop"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field Descriptions
|
||||
|
||||
### Segment Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `id` | `number` | Unique identifier for this segment. Used by clients to update specific segments efficiently. |
|
||||
| `speaker` | `number` | Speaker ID (1, 2, 3...). Special value `-2` indicates silence. |
|
||||
| `text` | `string` | Validated transcription text for this update. Should be **appended** to the segment's text on the client side. |
|
||||
| `start_speaker` | `float` | Timestamp (seconds) when this speaker segment began. |
|
||||
| `start` | `float` | Timestamp (seconds) of the first word in this update. |
|
||||
| `end` | `float` | Timestamp (seconds) of the last word in this update. |
|
||||
| `language` | `string \| null` | ISO language code (e.g., "en", "fr"). `null` until language is detected. |
|
||||
| `translation` | `string` | Validated translation text for this update. Should be **appended** to the segment's translation on the client side. |
|
||||
| `words` | `Array` | Array of word-level objects with timing and validation information. |
|
||||
| `buffer` | `Object` | Per-segment temporary buffers, see below |
|
||||
|
||||
### Word Object
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `text` | `string` | The word text. |
|
||||
| `start` | `number` | Start timestamp (seconds) of this word. |
|
||||
| `end` | `number` | End timestamp (seconds) of this word. |
|
||||
| `validated.text` | `boolean` | Whether the transcription text has been validated. If false, the word also appears in the `transcription` buffer. |
|
||||
| `validated.speaker` | `boolean` | Whether the speaker assignment has been validated. If false, the word also appears in the `diarization` buffer. |
|
||||
| `validated.language` | `boolean` | Whether the language detection has been validated. If false, the word also appears in the `translation` buffer. |
|
||||
|
||||
### Buffer Object (Per-Segment)
|
||||
|
||||
Buffers are **ephemeral**. They should be displayed to the user but not stored permanently in the frontend. Each update may contain a completely different buffer value, and the previous buffer content is likely to appear in the next validated text.
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `transcription` | `string` | Pending transcription text. Displayed immediately but **overwritten** on next update. |
|
||||
| `diarization` | `string` | Pending diarization text (text waiting for speaker assignment). Displayed immediately but **overwritten** on next update. |
|
||||
| `translation` | `string` | Pending translation text. Displayed immediately but **overwritten** on next update. |
|
||||
|
||||
|
||||
### Metadata Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `remaining_time_transcription` | `float` | Seconds of audio waiting for transcription processing. |
|
||||
| `remaining_time_diarization` | `float` | Seconds of audio waiting for speaker diarization. |
|
||||
|
||||
### Status Values
|
||||
|
||||
| Status | Description |
|
||||
|--------|-------------|
|
||||
| `active_transcription` | Normal operation, transcription is active. |
|
||||
| `no_audio_detected` | No audio has been detected yet. |
|
||||
|
||||
---
|
||||
|
||||
## Update Behavior
|
||||
|
||||
### Incremental Updates
|
||||
|
||||
The API sends **only changed or new segments**. Clients should:
|
||||
|
||||
1. Maintain a local map of segments by ID
|
||||
2. When receiving an update, merge/update segments by ID
|
||||
3. Render only the changed segments
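For illustration, a minimal client-side merge sketch in Python (field names follow this document; the exact policy, e.g. around language re-detection, should follow the sections below):

```python
segments = {}  # local map: segment id -> accumulated segment state

def merge_update(update: dict) -> None:
    for seg in update.get("segments", []):
        current = segments.get(seg["id"])
        if current is None or seg.get("language") != current.get("language"):
            # New segment, or the detected language changed: replace wholesale
            segments[seg["id"]] = dict(seg)
            continue
        # Validated text/translation arrive as deltas and are appended
        current["text"] = current.get("text", "") + seg.get("text", "")
        current["translation"] = current.get("translation", "") + seg.get("translation", "")
        # Speaker, timestamps, words and the ephemeral buffers are simply overwritten
        for key in ("speaker", "start", "end", "words", "buffer"):
            if key in seg:
                current[key] = seg[key]
```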
|
||||
|
||||
### Language Detection
|
||||
|
||||
When language is detected for a segment:
|
||||
|
||||
```jsonc
|
||||
// Update 1: No language yet
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "May see", "language": null}
|
||||
]
|
||||
}
|
||||
|
||||
// Update 2: Same segment ID, language now detected
|
||||
{
|
||||
"segments": [
|
||||
{"id": 1, "speaker": 1, "text": "Merci", "language": "fr"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Client behavior**: **Replace** the existing segment with the same ID.
|
||||
|
||||
### Buffer Behavior
|
||||
|
||||
Buffers are **per-segment** to handle multi-speaker scenarios correctly.
|
||||
|
||||
#### Example: Transcription with diarization and translation
|
||||
|
||||
```jsonc
|
||||
// Update 1
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": "Hello world, how are",
|
||||
"translation": "",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " you on",
|
||||
"translation": "Bonjour le monde"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are <DIARIZATION BUFFER> you on</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION><TRANSLATION BUFFER>Bonjour le monde</TRANSLATION BUFFER></TRANSLATION>
|
||||
|
||||
|
||||
// Update 2
|
||||
{
|
||||
"segments": [
|
||||
{
|
||||
"id": 1,
|
||||
"speaker": 1,
|
||||
"text": " you on this",
|
||||
"translation": "Bonjour tout le monde",
|
||||
"buffer": {
|
||||
"transcription": "",
|
||||
"diarization": " beautiful day",
|
||||
"translation": ",comment"
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// ==== Frontend ====
|
||||
// <SPEAKER>1</SPEAKER>
|
||||
// <TRANSCRIPTION>Hello world, how are you on this<DIARIZATION BUFFER> beautiful day</DIARIZATION BUFFER></TRANSCRIPTION>
|
||||
// <TRANSLATION>Bonjour tout le monde<TRANSLATION BUFFER>, comment</TRANSLATION BUFFER></TRANSLATION>
|
||||
```
|
||||
|
||||
### Silence Segments
|
||||
|
||||
Silence is represented with the speaker id = `-2`:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"id": 5,
|
||||
"speaker": -2,
|
||||
"text": "",
|
||||
"start": 10.5,
|
||||
"end": 12.3
|
||||
}
|
||||
```
|
||||
71
docs/alignement_principles.md
Normal file
@@ -0,0 +1,71 @@
|
||||
### Alignment between STT Tokens and Diarization Segments
|
||||
|
||||
- Example 1: The punctuation from STT and the speaker change from diarization both come in prediction `t`
- Example 2: The punctuation from STT comes in prediction `t`, but the speaker change from diarization comes in prediction `t-1`
- Example 3: The punctuation from STT comes in prediction `t-1`, but the speaker change from diarization comes in prediction `t`
|
||||
|
||||
> `#` is the split between the `t-1` prediction and the `t` prediction.
|
||||
|
||||
|
||||
## Example 1:
|
||||
```text
|
||||
punctuations_segments : __#_______.__________________!____
|
||||
diarization_segments:
|
||||
SPK1 __#____________
|
||||
SPK2 # ___________________
|
||||
-->
|
||||
ALIGNED SPK1 __#_______.
|
||||
ALIGNED SPK2 # __________________!____
|
||||
|
||||
t-1 output:
|
||||
SPK1: __#
|
||||
SPK2: NO
|
||||
DIARIZATION BUFFER: NO
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: __________________!____
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 2:
|
||||
```text
|
||||
punctuations_segments : _____#__.___________
|
||||
diarization_segments:
|
||||
SPK1 ___ #
|
||||
SPK2 __#______________
|
||||
-->
|
||||
ALIGNED SPK1 _____#__.
|
||||
ALIGNED SPK2 # ___________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___ #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: __#__.
|
||||
SPK2: ___________
|
||||
DIARIZATION BUFFER: No
|
||||
```
|
||||
|
||||
## Example 3:
|
||||
```text
|
||||
punctuations_segments : ___.__#__________
|
||||
diarization_segments:
|
||||
SPK1 ______#__
|
||||
SPK2 # ________
|
||||
-->
|
||||
ALIGNED SPK1 ___. #
|
||||
ALIGNED SPK2 __#__________
|
||||
|
||||
t-1 output:
|
||||
SPK1: ___. #
|
||||
SPK2:
|
||||
DIARIZATION BUFFER: __#
|
||||
|
||||
t output:
|
||||
SPK1: #
|
||||
SPK2: __#___________
|
||||
DIARIZATION BUFFER: NO
|
||||
```
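As a rough illustration of the idea behind these examples (this is not the project's implementation): words from the STT stream are grouped into sentences at punctuation, and each sentence is assigned to the speaker whose diarization segments cover most of its words.

```python
from dataclasses import dataclass

@dataclass
class Word:
    text: str
    start: float
    end: float

def assign_speakers(words: list[Word], diar: list[tuple[int, float, float]]) -> list[tuple[int, str]]:
    """diar: (speaker, start, end) segments. Returns (speaker, sentence) pairs."""
    def speaker_at(t: float) -> int:
        for spk, s, e in diar:
            if s <= t <= e:
                return spk
        return -1  # no diarization segment covers this time yet

    result, sentence, votes = [], [], {}
    for w in words:
        sentence.append(w.text)
        spk = speaker_at((w.start + w.end) / 2)
        votes[spk] = votes.get(spk, 0) + 1
        if w.text and w.text[-1] in ".!?":  # punctuation-based split point
            result.append((max(votes, key=votes.get), " ".join(sentence)))
            sentence, votes = [], {}
    return result
```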
|
||||
109
docs/available_models.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Available Whisper model sizes:
|
||||
|
||||
- tiny.en (english only)
|
||||
- tiny
|
||||
- base.en (english only)
|
||||
- base
|
||||
- small.en (english only)
|
||||
- small
|
||||
- medium.en (english only)
|
||||
- medium
|
||||
- large-v1
|
||||
- large-v2
|
||||
- large-v3
|
||||
- large-v3-turbo
|
||||
|
||||
## How to choose?
|
||||
|
||||
### Language Support
|
||||
- **English only**: Use `.en` models for better accuracy and faster processing when you only need English transcription
|
||||
- **Multilingual**: Do not use `.en` models.
|
||||
|
||||
### Resource Constraints
|
||||
- **Limited GPU/CPU or need for very low latency**: Choose `small` or smaller models
|
||||
- `tiny`: Fastest, lowest resource usage, acceptable quality for simple audio
|
||||
- `base`: Good balance of speed and accuracy for basic use cases
|
||||
- `small`: Better accuracy while still being resource-efficient
|
||||
- **Good resources available**: Use `large` models for best accuracy
|
||||
- `large-v2`: Excellent accuracy, good multilingual support
|
||||
- `large-v3`: Best overall accuracy and language support
|
||||
|
||||
### Special Cases
|
||||
- **No translation needed**: Use `large-v3-turbo`
|
||||
- Same transcription quality as `large-v2` but significantly faster
|
||||
- **Important**: Does not translate correctly, only transcribes
|
||||
|
||||
### Model Comparison Table
|
||||
|
||||
| Model | Speed | Accuracy | Multilingual | Translation | Best Use Case |
|
||||
|-------|--------|----------|--------------|-------------|---------------|
|
||||
| tiny(.en) | Fastest | Basic | Yes/No | Yes/No | Real-time, low resources |
|
||||
| base(.en) | Fast | Good | Yes/No | Yes/No | Balanced performance |
|
||||
| small(.en) | Medium | Better | Yes/No | Yes/No | Quality on limited hardware |
|
||||
| medium(.en) | Slow | High | Yes/No | Yes/No | High quality, moderate resources |
|
||||
| large-v2 | Slowest | Excellent | Yes | Yes | Best overall quality |
|
||||
| large-v3 | Slowest | Excellent | Yes | Yes | Maximum accuracy |
|
||||
| large-v3-turbo | Fast | Excellent | Yes | No | Fast, high-quality transcription |
|
||||
|
||||
### Additional Considerations
|
||||
|
||||
**Model Performance**:
|
||||
- Accuracy improves significantly from tiny to large models
|
||||
- English-only models are ~10-15% more accurate for English audio
|
||||
- Newer versions (v2, v3) have better punctuation and formatting
|
||||
|
||||
**Hardware Requirements**:
|
||||
- `tiny`: ~1GB VRAM
|
||||
- `base`: ~1GB VRAM
|
||||
- `small`: ~2GB VRAM
|
||||
- `medium`: ~5GB VRAM
|
||||
- `large`: ~10GB VRAM
|
||||
- `large-v3-turbo`: ~6GB VRAM
|
||||
|
||||
**Audio Quality Impact**:
|
||||
- Clean, clear audio: smaller models may suffice
|
||||
- Noisy, accented, or technical audio: larger models recommended
|
||||
- Phone/low-quality audio: use at least `small` model
|
||||
|
||||
### Quick Decision Tree
|
||||
1. English only? → Add `.en` to your choice
|
||||
2. Limited resources or need speed? → `small` or smaller
|
||||
3. Good hardware and want best quality? → `large-v3`
|
||||
4. Need fast, high-quality transcription without translation? → `large-v3-turbo`
|
||||
5. Need translation capabilities? → `large-v2` or `large-v3` (avoid turbo)
|
||||
|
||||
|
||||
_______________________
|
||||
|
||||
# Translation Models and Backend
|
||||
|
||||
**Language Support**: ~200 languages
|
||||
|
||||
## Distilled Model Sizes Available
|
||||
|
||||
| Model | Size | Parameters | VRAM (FP16) | VRAM (INT8) | Quality |
|
||||
|-------|------|------------|-------------|-------------|---------|
|
||||
| 600M | 2.46 GB | 600M | ~1.5GB | ~800MB | Good, understandable |
|
||||
| 1.3B | 5.48 GB | 1.3B | ~3GB | ~1.5GB | Better accuracy, context |
|
||||
|
||||
**Quality Impact**: 1.3B has ~15-25% better BLEU scores vs 600M across language pairs.
|
||||
|
||||
## Backend Performance
|
||||
|
||||
| Backend | Speed vs Base | Memory Usage | Quality Loss |
|
||||
|---------|---------------|--------------|--------------|
|
||||
| CTranslate2 | 6-10x faster | 40-60% less | ~5% BLEU drop |
|
||||
| Transformers | Baseline | High | None |
|
||||
| Transformers + MPS (on Apple Silicon) | 2x faster | Medium | None |
|
||||
|
||||
**Metrics**:
|
||||
- CTranslate2: 50-100+ tokens/sec
|
||||
- Transformers: 10-30 tokens/sec
|
||||
- Apple Silicon with MPS: Up to 2x faster than CTranslate2
|
||||
|
||||
## Quick Decision Matrix
|
||||
|
||||
**Choose 600M**: Limited resources, close to 0 lag
|
||||
**Choose 1.3B**: Quality matters
|
||||
**Choose Transformers**: On Apple Silicon
|
||||
|
||||
19
docs/models_compatible_formats.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Model Path Formats
|
||||
|
||||
The `--model-path` parameter accepts:
|
||||
|
||||
## File Path
|
||||
- **`.pt` / `.bin` / `.safetensor` formats**: must be loadable by PyTorch / safetensors.
|
||||
|
||||
## Directory Path (recommended)
|
||||
Must contain:
|
||||
- **`.pt` / `.bin` / `.safetensor` file** (required for decoder)
|
||||
|
||||
May optionally contain:
|
||||
- **`.bin` file** - faster-whisper model for encoder (requires faster-whisper)
|
||||
- **`weights.npz`** or **`weights.safetensors`** - for encoder (requires whisper-mlx)
|
||||
|
||||
## Hugging Face Repo ID
|
||||
- Provide the repo ID (e.g. `openai/whisper-large-v3`) and WhisperLiveKit will download and cache the snapshot automatically. For gated repos, authenticate via `huggingface-cli login` first.
|
||||
|
||||
To improve speed and reduce hallucinations, you may want to use `scripts/determine_alignment_heads.py` to determine the alignment heads for your model, and pass them to WLK with `--custom-alignment-heads`. If you don't, the alignment heads default to all heads of the last half of the decoder layers.
|
||||
265
docs/supported_languages.md
Normal file
@@ -0,0 +1,265 @@
|
||||
# Supported Languages
|
||||
|
||||
WhisperLiveKit supports translation into **201 languages** from the FLORES-200 dataset through the NLLB (No Language Left Behind) translation system.
|
||||
|
||||
## How to Specify Languages
|
||||
|
||||
You can specify languages in **three different ways**:
|
||||
|
||||
1. **Language Name** (case-insensitive): `"English"`, `"French"`, `"Spanish"`
|
||||
2. **ISO Language Code**: `"en"`, `"fr"`, `"es"`
|
||||
3. **NLLB Code** (FLORES-200): `"eng_Latn"`, `"fra_Latn"`, `"spa_Latn"`
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Command Line
|
||||
```bash
|
||||
# Using language name
|
||||
whisperlivekit-server --target-language "French"
|
||||
|
||||
# Using ISO code
|
||||
whisperlivekit-server --target-language fr
|
||||
|
||||
# Using NLLB code
|
||||
whisperlivekit-server --target-language fra_Latn
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from nllw.translation import get_language_info
|
||||
|
||||
# Get language information by name
|
||||
lang_info = get_language_info("French")
|
||||
print(lang_info)
|
||||
# {'name': 'French', 'nllb': 'fra_Latn', 'language_code': 'fr'}
|
||||
|
||||
# Get language information by ISO code
|
||||
lang_info = get_language_info("fr")
|
||||
|
||||
# Get language information by NLLB code
|
||||
lang_info = get_language_info("fra_Latn")
|
||||
|
||||
# All three return the same result
|
||||
```
|
||||
|
||||
## Complete Language List
|
||||
|
||||
The following table lists all 201 supported languages with their corresponding codes:
|
||||
|
||||
| Language Name | ISO Code | NLLB Code |
|
||||
|---------------|----------|-----------|
|
||||
| Acehnese (Arabic script) | ace_Arab | ace_Arab |
|
||||
| Acehnese (Latin script) | ace_Latn | ace_Latn |
|
||||
| Mesopotamian Arabic | acm_Arab | acm_Arab |
|
||||
| Ta'izzi-Adeni Arabic | acq_Arab | acq_Arab |
|
||||
| Tunisian Arabic | aeb_Arab | aeb_Arab |
|
||||
| Afrikaans | af | afr_Latn |
|
||||
| South Levantine Arabic | ajp_Arab | ajp_Arab |
|
||||
| Akan | ak | aka_Latn |
|
||||
| Tosk Albanian | als | als_Latn |
|
||||
| Amharic | am | amh_Ethi |
|
||||
| North Levantine Arabic | apc_Arab | apc_Arab |
|
||||
| Modern Standard Arabic | ar | arb_Arab |
|
||||
| Modern Standard Arabic (Romanized) | arb_Latn | arb_Latn |
|
||||
| Najdi Arabic | ars_Arab | ars_Arab |
|
||||
| Moroccan Arabic | ary_Arab | ary_Arab |
|
||||
| Egyptian Arabic | arz_Arab | arz_Arab |
|
||||
| Assamese | as | asm_Beng |
|
||||
| Asturian | ast | ast_Latn |
|
||||
| Awadhi | awa | awa_Deva |
|
||||
| Central Aymara | ay | ayr_Latn |
|
||||
| South Azerbaijani | azb | azb_Arab |
|
||||
| North Azerbaijani | az | azj_Latn |
|
||||
| Bashkir | ba | bak_Cyrl |
|
||||
| Bambara | bm | bam_Latn |
|
||||
| Balinese | ban | ban_Latn |
|
||||
| Belarusian | be | bel_Cyrl |
|
||||
| Bemba | bem | bem_Latn |
|
||||
| Bengali | bn | ben_Beng |
|
||||
| Bhojpuri | bho | bho_Deva |
|
||||
| Banjar (Arabic script) | bjn_Arab | bjn_Arab |
|
||||
| Banjar (Latin script) | bjn_Latn | bjn_Latn |
|
||||
| Standard Tibetan | bo | bod_Tibt |
|
||||
| Bosnian | bs | bos_Latn |
|
||||
| Buginese | bug | bug_Latn |
|
||||
| Bulgarian | bg | bul_Cyrl |
|
||||
| Catalan | ca | cat_Latn |
|
||||
| Cebuano | ceb | ceb_Latn |
|
||||
| Czech | cs | ces_Latn |
|
||||
| Chokwe | cjk | cjk_Latn |
|
||||
| Central Kurdish | ckb | ckb_Arab |
|
||||
| Crimean Tatar | crh | crh_Latn |
|
||||
| Welsh | cy | cym_Latn |
|
||||
| Danish | da | dan_Latn |
|
||||
| German | de | deu_Latn |
|
||||
| Southwestern Dinka | dik | dik_Latn |
|
||||
| Dyula | dyu | dyu_Latn |
|
||||
| Dzongkha | dz | dzo_Tibt |
|
||||
| Greek | el | ell_Grek |
|
||||
| English | en | eng_Latn |
|
||||
| Esperanto | eo | epo_Latn |
|
||||
| Estonian | et | est_Latn |
|
||||
| Basque | eu | eus_Latn |
|
||||
| Ewe | ee | ewe_Latn |
|
||||
| Faroese | fo | fao_Latn |
|
||||
| Fijian | fj | fij_Latn |
|
||||
| Finnish | fi | fin_Latn |
|
||||
| Fon | fon | fon_Latn |
|
||||
| French | fr | fra_Latn |
|
||||
| Friulian | fur-IT | fur_Latn |
|
||||
| Nigerian Fulfulde | fuv | fuv_Latn |
|
||||
| West Central Oromo | om | gaz_Latn |
|
||||
| Scottish Gaelic | gd | gla_Latn |
|
||||
| Irish | ga-IE | gle_Latn |
|
||||
| Galician | gl | glg_Latn |
|
||||
| Guarani | gn | grn_Latn |
|
||||
| Gujarati | gu-IN | guj_Gujr |
|
||||
| Haitian Creole | ht | hat_Latn |
|
||||
| Hausa | ha | hau_Latn |
|
||||
| Hebrew | he | heb_Hebr |
|
||||
| Hindi | hi | hin_Deva |
|
||||
| Chhattisgarhi | hne | hne_Deva |
|
||||
| Croatian | hr | hrv_Latn |
|
||||
| Hungarian | hu | hun_Latn |
|
||||
| Armenian | hy-AM | hye_Armn |
|
||||
| Igbo | ig | ibo_Latn |
|
||||
| Ilocano | ilo | ilo_Latn |
|
||||
| Indonesian | id | ind_Latn |
|
||||
| Icelandic | is | isl_Latn |
|
||||
| Italian | it | ita_Latn |
|
||||
| Javanese | jv | jav_Latn |
|
||||
| Japanese | ja | jpn_Jpan |
|
||||
| Kabyle | kab | kab_Latn |
|
||||
| Jingpho | kac | kac_Latn |
|
||||
| Kamba | kam | kam_Latn |
|
||||
| Kannada | kn | kan_Knda |
|
||||
| Kashmiri (Arabic script) | kas_Arab | kas_Arab |
|
||||
| Kashmiri (Devanagari script) | kas_Deva | kas_Deva |
|
||||
| Georgian | ka | kat_Geor |
|
||||
| Kazakh | kk | kaz_Cyrl |
|
||||
| Kabiyè | kbp | kbp_Latn |
|
||||
| Kabuverdianu | kea | kea_Latn |
|
||||
| Halh Mongolian | mn | khk_Cyrl |
|
||||
| Khmer | km | khm_Khmr |
|
||||
| Kikuyu | ki | kik_Latn |
|
||||
| Kinyarwanda | rw | kin_Latn |
|
||||
| Kyrgyz | ky | kir_Cyrl |
|
||||
| Kimbundu | kmb | kmb_Latn |
|
||||
| Northern Kurdish | kmr | kmr_Latn |
|
||||
| Central Kanuri (Arabic script) | knc_Arab | knc_Arab |
|
||||
| Central Kanuri (Latin script) | knc_Latn | knc_Latn |
|
||||
| Kikongo | kg | kon_Latn |
|
||||
| Korean | ko | kor_Hang |
|
||||
| Lao | lo | lao_Laoo |
|
||||
| Ligurian | lij | lij_Latn |
|
||||
| Limburgish | li | lim_Latn |
|
||||
| Lingala | ln | lin_Latn |
|
||||
| Lithuanian | lt | lit_Latn |
|
||||
| Lombard | lmo | lmo_Latn |
|
||||
| Latgalian | ltg | ltg_Latn |
|
||||
| Luxembourgish | lb | ltz_Latn |
|
||||
| Luba-Kasai | lua | lua_Latn |
|
||||
| Ganda | lg | lug_Latn |
|
||||
| Luo | luo | luo_Latn |
|
||||
| Mizo | lus | lus_Latn |
|
||||
| Standard Latvian | lv | lvs_Latn |
|
||||
| Magahi | mag | mag_Deva |
|
||||
| Maithili | mai | mai_Deva |
|
||||
| Malayalam | ml-IN | mal_Mlym |
|
||||
| Marathi | mr | mar_Deva |
|
||||
| Minangkabau (Arabic script) | min_Arab | min_Arab |
|
||||
| Minangkabau (Latin script) | min_Latn | min_Latn |
|
||||
| Macedonian | mk | mkd_Cyrl |
|
||||
| Maltese | mt | mlt_Latn |
|
||||
| Meitei (Bengali script) | mni | mni_Beng |
|
||||
| Mossi | mos | mos_Latn |
|
||||
| Maori | mi | mri_Latn |
|
||||
| Burmese | my | mya_Mymr |
|
||||
| Dutch | nl | nld_Latn |
|
||||
| Norwegian Nynorsk | nn-NO | nno_Latn |
|
||||
| Norwegian Bokmål | nb | nob_Latn |
|
||||
| Nepali | ne-NP | npi_Deva |
|
||||
| Northern Sotho | nso | nso_Latn |
|
||||
| Nuer | nus | nus_Latn |
|
||||
| Nyanja | ny | nya_Latn |
|
||||
| Occitan | oc | oci_Latn |
|
||||
| Odia | or | ory_Orya |
|
||||
| Pangasinan | pag | pag_Latn |
|
||||
| Eastern Panjabi | pa | pan_Guru |
|
||||
| Papiamento | pap | pap_Latn |
|
||||
| Southern Pashto | pbt | pbt_Arab |
|
||||
| Western Persian | fa | pes_Arab |
|
||||
| Plateau Malagasy | mg | plt_Latn |
|
||||
| Polish | pl | pol_Latn |
|
||||
| Portuguese | pt-PT | por_Latn |
|
||||
| Dari | fa-AF | prs_Arab |
|
||||
| Ayacucho Quechua | qu | quy_Latn |
|
||||
| Romanian | ro | ron_Latn |
|
||||
| Rundi | rn | run_Latn |
|
||||
| Russian | ru | rus_Cyrl |
|
||||
| Sango | sg | sag_Latn |
|
||||
| Sanskrit | sa | san_Deva |
|
||||
| Santali | sat | sat_Olck |
|
||||
| Sicilian | scn | scn_Latn |
|
||||
| Shan | shn | shn_Mymr |
|
||||
| Sinhala | si-LK | sin_Sinh |
|
||||
| Slovak | sk | slk_Latn |
|
||||
| Slovenian | sl | slv_Latn |
|
||||
| Samoan | sm | smo_Latn |
|
||||
| Shona | sn | sna_Latn |
|
||||
| Sindhi | sd | snd_Arab |
|
||||
| Somali | so | som_Latn |
|
||||
| Southern Sotho | st | sot_Latn |
|
||||
| Spanish | es-ES | spa_Latn |
|
||||
| Sardinian | sc | srd_Latn |
|
||||
| Serbian | sr | srp_Cyrl |
|
||||
| Swati | ss | ssw_Latn |
|
||||
| Sundanese | su | sun_Latn |
|
||||
| Swedish | sv-SE | swe_Latn |
|
||||
| Swahili | sw | swh_Latn |
|
||||
| Silesian | szl | szl_Latn |
|
||||
| Tamil | ta | tam_Taml |
|
||||
| Tamasheq (Latin script) | taq_Latn | taq_Latn |
|
||||
| Tamasheq (Tifinagh script) | taq_Tfng | taq_Tfng |
|
||||
| Tatar | tt-RU | tat_Cyrl |
|
||||
| Telugu | te | tel_Telu |
|
||||
| Tajik | tg | tgk_Cyrl |
|
||||
| Tagalog | tl | tgl_Latn |
|
||||
| Thai | th | tha_Thai |
|
||||
| Tigrinya | ti | tir_Ethi |
|
||||
| Tok Pisin | tpi | tpi_Latn |
|
||||
| Tswana | tn | tsn_Latn |
|
||||
| Tsonga | ts | tso_Latn |
|
||||
| Turkmen | tk | tuk_Latn |
|
||||
| Tumbuka | tum | tum_Latn |
|
||||
| Turkish | tr | tur_Latn |
|
||||
| Twi | tw | twi_Latn |
|
||||
| Central Atlas Tamazight | tzm | tzm_Tfng |
|
||||
| Uyghur | ug | uig_Arab |
|
||||
| Ukrainian | uk | ukr_Cyrl |
|
||||
| Umbundu | umb | umb_Latn |
|
||||
| Urdu | ur | urd_Arab |
|
||||
| Northern Uzbek | uz | uzn_Latn |
|
||||
| Venetian | vec | vec_Latn |
|
||||
| Vietnamese | vi | vie_Latn |
|
||||
| Waray | war | war_Latn |
|
||||
| Wolof | wo | wol_Latn |
|
||||
| Xhosa | xh | xho_Latn |
|
||||
| Eastern Yiddish | yi | ydd_Hebr |
|
||||
| Yoruba | yo | yor_Latn |
|
||||
| Yue Chinese | yue | yue_Hant |
|
||||
| Chinese (Simplified) | zh-CN | zho_Hans |
|
||||
| Chinese (Traditional) | zh-TW | zho_Hant |
|
||||
| Standard Malay | ms | zsm_Latn |
|
||||
| Zulu | zu | zul_Latn |
|
||||
|
||||
## Special Features
|
||||
|
||||
### Multiple Script Support
|
||||
Several languages are available in multiple scripts (e.g., Arabic and Latin):
|
||||
- **Acehnese**: Arabic (`ace_Arab`) and Latin (`ace_Latn`)
|
||||
- **Banjar**: Arabic (`bjn_Arab`) and Latin (`bjn_Latn`)
|
||||
- **Kashmiri**: Arabic (`kas_Arab`) and Devanagari (`kas_Deva`)
|
||||
- **Minangkabau**: Arabic (`min_Arab`) and Latin (`min_Latn`)
|
||||
- **Tamasheq**: Latin (`taq_Latn`) and Tifinagh (`taq_Tfng`)
|
||||
- **Central Kanuri**: Arabic (`knc_Arab`) and Latin (`knc_Latn`)
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""Functions for sending and receiving individual lines of text over a socket.
|
||||
|
||||
A line is transmitted using one or more fixed-size packets of UTF-8 bytes
|
||||
containing:
|
||||
|
||||
- Zero or more bytes of UTF-8, excluding \n and \0, followed by
|
||||
|
||||
- Zero or more \0 bytes as required to pad the packet to PACKET_SIZE
|
||||
|
||||
Originally from the UEDIN team of the ELITR project.
|
||||
"""
|
||||
|
||||
PACKET_SIZE = 65536
|
||||
|
||||
|
||||
def send_one_line(socket, text):
|
||||
"""Sends a line of text over the given socket.
|
||||
|
||||
The 'text' argument should contain a single line of text (line break
|
||||
characters are optional). Line boundaries are determined by Python's
|
||||
str.splitlines() function [1]. We also count '\0' as a line terminator.
|
||||
If 'text' contains multiple lines then only the first will be sent.
|
||||
|
||||
If the send fails then an exception will be raised.
|
||||
|
||||
[1] https://docs.python.org/3.5/library/stdtypes.html#str.splitlines
|
||||
|
||||
Args:
|
||||
socket: a socket object.
|
||||
text: string containing a line of text for transmission.
|
||||
"""
|
||||
text = text.replace('\0', '\n')
|
||||
lines = text.splitlines()
|
||||
first_line = '' if len(lines) == 0 else lines[0]
|
||||
# TODO Is there a better way of handling bad input than 'replace'?
|
||||
data = first_line.encode('utf-8', errors='replace') + b'\n\0'
|
||||
for offset in range(0, len(data), PACKET_SIZE):
|
||||
bytes_remaining = len(data) - offset
|
||||
if bytes_remaining < PACKET_SIZE:
|
||||
padding_length = PACKET_SIZE - bytes_remaining
|
||||
packet = data[offset:] + b'\0' * padding_length
|
||||
else:
|
||||
packet = data[offset:offset+PACKET_SIZE]
|
||||
socket.sendall(packet)
|
||||
|
||||
|
||||
def receive_one_line(socket):
|
||||
"""Receives a line of text from the given socket.
|
||||
|
||||
This function will (attempt to) receive a single line of text. If data is
|
||||
currently unavailable then it will block until data becomes available or
|
||||
the sender has closed the connection (in which case it will return an
|
||||
empty string).
|
||||
|
||||
The string should not contain any newline characters, but if it does then
|
||||
only the first line will be returned.
|
||||
|
||||
Args:
|
||||
socket: a socket object.
|
||||
|
||||
Returns:
|
||||
A string representing a single line with a terminating newline or
|
||||
None if the connection has been closed.
|
||||
"""
|
||||
data = b''
|
||||
while True:
|
||||
packet = socket.recv(PACKET_SIZE)
|
||||
if not packet: # Connection has been closed.
|
||||
return None
|
||||
data += packet
|
||||
if b'\0' in packet:
|
||||
break
|
||||
# TODO Is there a better way of handling bad input than 'replace'?
|
||||
text = data.decode('utf-8', errors='replace').strip('\0')
|
||||
lines = text.split('\n')
|
||||
return lines[0] + '\n'
|
||||
|
||||
|
||||
def receive_lines(socket):
|
||||
try:
|
||||
data = socket.recv(PACKET_SIZE)
|
||||
except BlockingIOError:
|
||||
return []
|
||||
if not data: # Connection has been closed.
|
||||
return None
|
||||
# TODO Is there a better way of handling bad input than 'replace'?
|
||||
text = data.decode('utf-8', errors='replace').strip('\0')
|
||||
lines = text.split('\n')
|
||||
if len(lines)==1 and not lines[0]:
|
||||
return None
|
||||
return lines
|
||||
70
pyproject.toml
Normal file
@@ -0,0 +1,70 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "whisperlivekit"
version = "0.2.14.post4"
description = "Real-time speech-to-text with speaker diarization using Whisper"
readme = "README.md"
authors = [
    { name = "Quentin Fuxa" }
]
license = { file = "LICENSE" }
requires-python = ">=3.9"
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Programming Language :: Python :: 3.14",
    "Programming Language :: Python :: 3.15",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Multimedia :: Sound/Audio :: Speech"
]
dependencies = [
    "fastapi",
    "librosa",
    "soundfile",
    "uvicorn",
    "websockets",
    "torchaudio>=2.0.0",
    "torch>=2.0.0",
    "huggingface-hub>=0.25.0",
    "tqdm",
    "tiktoken",
    'triton>=2.0.0; platform_machine == "x86_64" and (sys_platform == "linux" or sys_platform == "linux2")'
]

[project.optional-dependencies]
translation = ["nllw"]
sentence_tokenizer = ["mosestokenizer", "wtpsplit"]

[project.urls]
Homepage = "https://github.com/QuentinFuxa/WhisperLiveKit"

[project.scripts]
whisperlivekit-server = "whisperlivekit.basic_server:main"
wlk = "whisperlivekit.basic_server:main"

[tool.setuptools]
packages = [
    "whisperlivekit",
    "whisperlivekit.diarization",
    "whisperlivekit.simul_whisper",
    "whisperlivekit.whisper",
    "whisperlivekit.whisper.assets",
    "whisperlivekit.whisper.normalizers",
    "whisperlivekit.web",
    "whisperlivekit.local_agreement",
    "whisperlivekit.vad_models"
]

[tool.setuptools.package-data]
whisperlivekit = ["web/*.html", "web/*.css", "web/*.js", "web/src/*.svg"]
"whisperlivekit.whisper.assets" = ["*.tiktoken", "*.npz"]
"whisperlivekit.vad_models" = ["*.jit", "*.onnx"]
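The `[project.scripts]` table above is what installs the `whisperlivekit-server` and `wlk` commands. As a rough point of reference, calling the same entry point directly from Python looks like the sketch below; the behaviour of `main()` is an assumption, since `whisperlivekit/basic_server.py` is not shown in this diff.

```python
# Rough equivalent of the `wlk` / `whisperlivekit-server` console scripts
# declared in [project.scripts]; assumes the package is installed.
from whisperlivekit.basic_server import main

if __name__ == "__main__":
    main()  # same function the console scripts resolve to
```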
BIN
scripts/alignment_heads.png
Normal file
(binary image, 276 KiB)
153
scripts/convert_hf_whisper.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert a Hugging Face style Whisper checkpoint into a WhisperLiveKit .pt file.
|
||||
|
||||
Optionally shrink the supported audio chunk length (in seconds) by trimming the
|
||||
encoder positional embeddings and updating the stored model dimensions.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from whisperlivekit.whisper.audio import HOP_LENGTH, SAMPLE_RATE
|
||||
from whisperlivekit.whisper.model import ModelDimensions
|
||||
from whisperlivekit.whisper.utils import exact_div
|
||||
from whisperlivekit.whisper import _convert_hf_state_dict
|
||||
|
||||
|
||||
def _load_state_dict(repo_path: Path) -> Dict[str, torch.Tensor]:
|
||||
safetensor_path = repo_path / "model.safetensors"
|
||||
bin_path = repo_path / "pytorch_model.bin"
|
||||
|
||||
if safetensor_path.is_file():
|
||||
try:
|
||||
from safetensors.torch import load_file # type: ignore
|
||||
except Exception as exc: # pragma: no cover - import guard
|
||||
raise RuntimeError(
|
||||
"Install safetensors to load model.safetensors "
|
||||
"(pip install safetensors)"
|
||||
) from exc
|
||||
return load_file(str(safetensor_path))
|
||||
|
||||
if bin_path.is_file():
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find model.safetensors or pytorch_model.bin under {repo_path}"
|
||||
)
|
||||
|
||||
|
||||
def _load_config(repo_path: Path) -> Dict:
|
||||
config_path = repo_path / "config.json"
|
||||
if not config_path.is_file():
|
||||
raise FileNotFoundError(
|
||||
f"Hugging Face checkpoint at {repo_path} is missing config.json"
|
||||
)
|
||||
with open(config_path, "r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def _derive_audio_ctx(chunk_length: float) -> Tuple[int, int]:
|
||||
n_samples = int(round(chunk_length * SAMPLE_RATE))
|
||||
expected_samples = chunk_length * SAMPLE_RATE
|
||||
if abs(n_samples - expected_samples) > 1e-6:
|
||||
raise ValueError(
|
||||
"chunk_length must align with sample rate so that "
|
||||
"chunk_length * SAMPLE_RATE is an integer"
|
||||
)
|
||||
n_frames = exact_div(n_samples, HOP_LENGTH)
|
||||
n_audio_ctx = exact_div(n_frames, 2)
|
||||
return n_frames, n_audio_ctx
|
||||
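# Worked example (illustrative comment, not part of the script): with Whisper's
# standard SAMPLE_RATE = 16000 and HOP_LENGTH = 160, _derive_audio_ctx(30.0)
# yields 480000 samples -> 3000 mel frames -> 1500 encoder positions, while
# _derive_audio_ctx(10.0) yields 160000 samples -> 1000 frames -> 500 positions.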
|
||||
|
||||
def _build_dims(config: Dict, chunk_length: float) -> Dict:
|
||||
base_dims = ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers") or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
).__dict__.copy()
|
||||
|
||||
_, n_audio_ctx = _derive_audio_ctx(chunk_length)
|
||||
base_dims["n_audio_ctx"] = n_audio_ctx
|
||||
base_dims["chunk_length"] = chunk_length
|
||||
return base_dims
|
||||
|
||||
|
||||
def _trim_positional_embedding(
|
||||
state_dict: Dict[str, torch.Tensor], target_ctx: int
|
||||
) -> None:
|
||||
key = "encoder.positional_embedding"
|
||||
if key not in state_dict:
|
||||
raise KeyError(f"{key} missing from converted state dict")
|
||||
|
||||
tensor = state_dict[key]
|
||||
if tensor.shape[0] < target_ctx:
|
||||
raise ValueError(
|
||||
f"Cannot increase encoder ctx from {tensor.shape[0]} to {target_ctx}"
|
||||
)
|
||||
if tensor.shape[0] == target_ctx:
|
||||
return
|
||||
state_dict[key] = tensor[:target_ctx].contiguous()
|
||||
|
||||
|
||||
def convert_checkpoint(hf_path: Path, output_path: Path, chunk_length: float) -> None:
|
||||
state_dict = _load_state_dict(hf_path)
|
||||
converted = _convert_hf_state_dict(state_dict)
|
||||
|
||||
config = _load_config(hf_path)
|
||||
dims = _build_dims(config, chunk_length)
|
||||
|
||||
_trim_positional_embedding(converted, dims["n_audio_ctx"])
|
||||
|
||||
package = {"dims": dims, "model_state_dict": converted}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(package, output_path)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Hugging Face Whisper checkpoint to WhisperLiveKit format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"hf_path",
|
||||
type=str,
|
||||
help="Path to the cloned Hugging Face repository (e.g. whisper-tiny.en)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="converted-whisper.pt",
|
||||
help="Destination path for the .pt file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk-length",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Audio chunk length in seconds to support (default: 30)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
hf_path = Path(os.path.expanduser(args.hf_path)).resolve()
|
||||
output_path = Path(os.path.expanduser(args.output)).resolve()
|
||||
|
||||
convert_checkpoint(hf_path, output_path, args.chunk_length)
|
||||
print(f"Saved converted checkpoint to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
292
scripts/determine_alignment_heads.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""Determine alignment heads for a variants, such as distilled model"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import gzip
|
||||
import io
|
||||
import pathlib
|
||||
import sys
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio as DatasetAudio, load_dataset
|
||||
import soundfile as sf
|
||||
import matplotlib.pyplot as plt
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
||||
WHISPER_ROOT = REPO_ROOT / "whisper"
|
||||
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
sys.path.insert(0, str(WHISPER_ROOT))
|
||||
|
||||
from whisper import load_model
|
||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisper.tokenizer import get_tokenizer
|
||||
|
||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||
|
||||
|
||||
def load_dataset_clips(name, config, split, limit):
|
||||
ds = load_dataset(name, config, split=split)
|
||||
ds = ds.cast_column("audio", DatasetAudio(decode=False))
|
||||
clips = []
|
||||
for idx, row in enumerate(ds):
|
||||
if limit is not None and idx >= limit:
|
||||
break
|
||||
audio_field = row["audio"]
|
||||
transcript = row["text"]
|
||||
|
||||
waveform_np, _ = sf.read(io.BytesIO(audio_field["bytes"]), dtype="float32")
|
||||
if waveform_np.ndim > 1:
|
||||
waveform_np = waveform_np.mean(axis=1)
|
||||
waveform = waveform_np
|
||||
transcript = str(transcript)
|
||||
|
||||
clips.append((waveform, transcript))
|
||||
return clips
|
||||
|
||||
|
||||
def load_clips(args):
|
||||
return load_dataset_clips(
|
||||
args.dataset,
|
||||
args.dataset_config,
|
||||
args.dataset_split,
|
||||
args.dataset_num_samples,
|
||||
)
|
||||
|
||||
|
||||
def _waveform_from_source(source: AudioInput) -> torch.Tensor:
|
||||
waveform = torch.from_numpy(source.astype(np.float32, copy=False))
|
||||
return waveform
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="pytorch_model.bin",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="Torch device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="librispeech_asr"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-config",
|
||||
type=str,
|
||||
default="clean"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="validation[:1%]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-num-samples",
|
||||
type=int,
|
||||
default=16,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threshold",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Z score threshold for a head to be selected",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--votes",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="percentage of clips that must vote for a head",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="alignment_heads.b85",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--visualize-top-k",
|
||||
type=int,
|
||||
default=32,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def collect_heads(
|
||||
model,
|
||||
tokenizer,
|
||||
clips: Sequence[Tuple[AudioInput, str]],
|
||||
threshold: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
device = model.device
|
||||
votes = torch.zeros(model.dims.n_text_layer, model.dims.n_text_head, device=device)
|
||||
strengths = torch.zeros_like(votes)
|
||||
|
||||
for audio_source, transcript in clips:
|
||||
waveform = pad_or_trim(_waveform_from_source(audio_source))
|
||||
mel = log_mel_spectrogram(waveform, device=device)
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
|
||||
qks = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: qks.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
for layer_idx, tensor in enumerate(qks):
|
||||
if tensor is None:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1)
|
||||
peak = tensor.max(dim=-1).values # [heads, tokens]
|
||||
strengths[layer_idx] += peak.mean(dim=-1)
|
||||
zscore = (peak - peak.mean(dim=-1, keepdim=True)) / (
|
||||
peak.std(dim=-1, keepdim=True, unbiased=False) + 1e-6
|
||||
)
|
||||
            mask = (zscore > threshold).any(dim=-1)
|
||||
votes[layer_idx] += mask.float()
|
||||
|
||||
votes /= len(clips)
|
||||
strengths /= len(clips)
|
||||
return votes, strengths
|
||||
|
||||
|
||||
def _select_heads_for_visualization(selection, strengths, top_k):
|
||||
selected = torch.nonzero(selection, as_tuple=False)
|
||||
if selected.numel() == 0:
|
||||
return []
|
||||
|
||||
entries = [
|
||||
(int(layer.item()), int(head.item()), float(strengths[layer, head].item()))
|
||||
for layer, head in selected
|
||||
]
|
||||
entries.sort(key=lambda item: item[2], reverse=True)
|
||||
return entries[:top_k]
|
||||
|
||||
def _extract_heatmaps(
|
||||
model,
|
||||
tokenizer,
|
||||
clip: Tuple[AudioInput, str],
|
||||
heads: Sequence[Tuple[int, int, float]],
|
||||
) -> dict:
|
||||
if not heads:
|
||||
return {}
|
||||
|
||||
target_map = {}
|
||||
for layer, head, _ in heads:
|
||||
target_map.setdefault(layer, set()).add(head)
|
||||
|
||||
waveform = pad_or_trim(_waveform_from_source(clip[0]))
|
||||
mel = log_mel_spectrogram(waveform, device=model.device)
|
||||
transcript = clip[1]
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*tokenizer.encode(transcript),
|
||||
tokenizer.eot,
|
||||
],
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, __, outputs, index=i: QKs.__setitem__(index, outputs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
model(mel.unsqueeze(0), tokens.unsqueeze(0))
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
heatmaps = {}
|
||||
for layer_idx, tensor in enumerate(QKs):
|
||||
if tensor is None or layer_idx not in target_map:
|
||||
continue
|
||||
tensor = tensor[:, :, : mel.shape[-1] // 2]
|
||||
tensor = tensor.softmax(dim=-1).cpu()
|
||||
for head_idx in target_map[layer_idx]:
|
||||
heatmaps[(layer_idx, head_idx)] = tensor[head_idx]
|
||||
|
||||
return heatmaps
|
||||
|
||||
|
||||
def _plot_heatmaps(
|
||||
heads, heatmaps, output_path):
|
||||
cols = min(3, len(heads))
|
||||
rows = math.ceil(len(heads) / cols)
|
||||
fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3.2 * rows), squeeze=False)
|
||||
|
||||
for idx, (layer, head, score) in enumerate(heads):
|
||||
ax = axes[idx // cols][idx % cols]
|
||||
mat = heatmaps.get((layer, head))
|
||||
if mat is None:
|
||||
ax.axis("off")
|
||||
continue
|
||||
im = ax.imshow(mat.to(torch.float32).numpy(), aspect="auto", origin="lower")
|
||||
ax.set_title(f"L{layer} H{head} · score {score:.2f}")
|
||||
ax.set_xlabel("time")
|
||||
ax.set_ylabel("tokens")
|
||||
|
||||
for j in range(len(heads), rows * cols):
|
||||
axes[j // cols][j % cols].axis("off")
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _dump_mask(mask: torch.Tensor, output_path: str):
|
||||
payload = mask.numpy().astype(np.bool_)
|
||||
blob = base64.b85encode(gzip.compress(payload.tobytes()))
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(blob)
|
||||
|
||||
|
||||
def main():
|
||||
args = _parse_args()
|
||||
model = load_model(args.model, device=args.device)
|
||||
model.eval()
|
||||
tokenizer = get_tokenizer(multilingual=model.is_multilingual)
|
||||
clips = load_clips(args)
|
||||
|
||||
votes, strengths = collect_heads(model, tokenizer, clips, args.threshold)
|
||||
# selection = votes > 0.5
|
||||
selection = strengths > 0.05
|
||||
_dump_mask(selection.cpu(), args.output)
|
||||
|
||||
viz_heads = _select_heads_for_visualization(selection, strengths, args.visualize_top_k)
|
||||
heatmaps = _extract_heatmaps(model, tokenizer, clips[0], viz_heads)
|
||||
_plot_heatmaps(viz_heads, heatmaps, "alignment_heads.png")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
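To consume the mask that `_dump_mask` writes, the encoding can simply be reversed. A minimal sketch under the assumptions visible above (the shape `(n_text_layer, n_text_head)` mirrors how the mask is built in `collect_heads`, and the file name is just the script's default output; the helper name is illustrative):

```python
import base64
import gzip

import numpy as np
import torch


def load_alignment_heads(path: str, n_text_layer: int, n_text_head: int) -> torch.Tensor:
    """Inverse of _dump_mask: b85 text -> gzip bytes -> boolean [layers, heads] mask."""
    with open(path, "rb") as f:
        blob = f.read()
    raw = gzip.decompress(base64.b85decode(blob))
    mask = np.frombuffer(raw, dtype=bool).reshape(n_text_layer, n_text_head)
    return torch.from_numpy(mask.copy())

# e.g. heads = load_alignment_heads("alignment_heads.b85",
#                                   model.dims.n_text_layer, model.dims.n_text_head)
```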
39
scripts/sync_extension.py
Normal file
@@ -0,0 +1,39 @@
"""Copy core files from web directory to Chrome extension directory."""

import shutil
from pathlib import Path


def sync_extension_files():
    web_dir = Path("whisperlivekit/web")
    extension_dir = Path("chrome-extension")

    files_to_sync = [
        "live_transcription.html", "live_transcription.js", "live_transcription.css"
    ]

    svg_files = [
        "system_mode.svg",
        "light_mode.svg",
        "dark_mode.svg",
        "settings.svg"
    ]

    for file in files_to_sync:
        src_path = web_dir / file
        dest_path = extension_dir / file

        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_path, dest_path)

    for svg_file in svg_files:
        src_path = web_dir / "src" / svg_file
        dest_path = extension_dir / "web" / "src" / svg_file
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_path, dest_path)


if __name__ == "__main__":
    sync_extension_files()
@@ -1,738 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
|
||||
|
||||
import io
|
||||
import soundfile as sf
|
||||
import math
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@lru_cache
|
||||
def load_audio(fname):
|
||||
a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
|
||||
return a
|
||||
|
||||
def load_audio_chunk(fname, beg, end):
|
||||
audio = load_audio(fname)
|
||||
beg_s = int(beg*16000)
|
||||
end_s = int(end*16000)
|
||||
return audio[beg_s:end_s]
|
||||
|
||||
|
||||
# Whisper backend
|
||||
|
||||
class ASRBase:
|
||||
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when neeeded)
|
||||
|
||||
def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
|
||||
self.transcribe_kargs = {}
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
self.original_language = lan
|
||||
|
||||
self.model = self.load_model(modelsize, cache_dir, model_dir)
|
||||
|
||||
|
||||
    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperTimestampedASR(ASRBase):
|
||||
"""Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
|
||||
On the other hand, the installation for GPU could be easier.
|
||||
"""
|
||||
|
||||
sep = " "
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
import whisper
|
||||
import whisper_timestamped
|
||||
from whisper_timestamped import transcribe_timestamped
|
||||
self.transcribe_timestamped = transcribe_timestamped
|
||||
if model_dir is not None:
|
||||
logger.debug("ignoring model_dir, not implemented")
|
||||
return whisper.load_model(modelsize, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
result = self.transcribe_timestamped(self.model,
|
||||
audio, language=self.original_language,
|
||||
initial_prompt=init_prompt, verbose=None,
|
||||
condition_on_previous_text=True, **self.transcribe_kargs)
|
||||
return result
|
||||
|
||||
def ts_words(self,r):
|
||||
# return: transcribe result object to [(beg,end,"word1"), ...]
|
||||
o = []
|
||||
for s in r["segments"]:
|
||||
for w in s["words"]:
|
||||
t = (w["start"],w["end"],w["text"])
|
||||
o.append(t)
|
||||
return o
|
||||
|
||||
def segments_end_ts(self, res):
|
||||
return [s["end"] for s in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
|
||||
"""
|
||||
|
||||
sep = ""
|
||||
|
||||
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
# logging.getLogger("faster_whisper").setLevel(logger.level)
|
||||
if model_dir is not None:
|
||||
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
|
||||
model_size_or_path = model_dir
|
||||
elif modelsize is not None:
|
||||
model_size_or_path = modelsize
|
||||
else:
|
||||
raise ValueError("modelsize or model_dir parameter must be set")
|
||||
|
||||
|
||||
# this worked fast and reliably on NVIDIA L40
|
||||
model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
|
||||
|
||||
# or run on GPU with INT8
|
||||
# tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
|
||||
#model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
|
||||
|
||||
# or run on CPU with INT8
|
||||
# tested: works, but slow, appx 10-times than cuda FP16
|
||||
# model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
|
||||
return model
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
|
||||
# tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
|
||||
segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
|
||||
#print(info) # info contains language detection result
|
||||
|
||||
return list(segments)
|
||||
|
||||
def ts_words(self, segments):
|
||||
o = []
|
||||
for segment in segments:
|
||||
for word in segment.words:
|
||||
# not stripping the spaces -- should not be merged with them!
|
||||
w = word.word
|
||||
t = (word.start, word.end, w)
|
||||
o.append(t)
|
||||
return o
|
||||
|
||||
def segments_end_ts(self, res):
|
||||
return [s.end for s in res]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.transcribe_kargs["task"] = "translate"
|
||||
|
||||
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for audio transcription."""
|
||||
|
||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
|
||||
self.modelname = "whisper-1"
|
||||
self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
|
||||
self.response_format = "verbose_json"
|
||||
self.temperature = temperature
|
||||
|
||||
self.load_model()
|
||||
|
||||
self.use_vad_opt = False
|
||||
|
||||
# reset the task in set_translate_task
|
||||
self.task = "transcribe"
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI()
|
||||
|
||||
self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
|
||||
|
||||
|
||||
def ts_words(self, segments):
|
||||
no_speech_segments = []
|
||||
if self.use_vad_opt:
|
||||
for segment in segments.segments:
|
||||
# TODO: threshold can be set from outside
|
||||
if segment["no_speech_prob"] > 0.8:
|
||||
no_speech_segments.append((segment.get("start"), segment.get("end")))
|
||||
|
||||
o = []
|
||||
for word in segments.words:
|
||||
start = word.get("start")
|
||||
end = word.get("end")
|
||||
if any(s[0] <= start <= s[1] for s in no_speech_segments):
|
||||
# print("Skipping word", word.get("word"), "because it's in a no-speech segment")
|
||||
continue
|
||||
o.append((start, end, word.get("word")))
|
||||
return o
|
||||
|
||||
|
||||
def segments_end_ts(self, res):
|
||||
return [s["end"] for s in res.words]
|
||||
|
||||
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
||||
# Write the audio data to a buffer
|
||||
buffer = io.BytesIO()
|
||||
buffer.name = "temp.wav"
|
||||
sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
|
||||
buffer.seek(0) # Reset buffer's position to the beginning
|
||||
|
||||
self.transcribed_seconds += math.ceil(len(audio_data)/16000) # it rounds up to the whole seconds
|
||||
|
||||
params = {
|
||||
"model": self.modelname,
|
||||
"file": buffer,
|
||||
"response_format": self.response_format,
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word", "segment"]
|
||||
}
|
||||
if self.task != "translate" and self.original_language:
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
|
||||
if self.task == "translate":
|
||||
proc = self.client.audio.translations
|
||||
else:
|
||||
proc = self.client.audio.transcriptions
|
||||
|
||||
# Process transcription/translation
|
||||
transcript = proc.create(**params)
|
||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||
|
||||
return transcript
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
|
||||
def set_translate_task(self):
|
||||
self.task = "translate"
|
||||
|
||||
|
||||
|
||||
|
||||
class HypothesisBuffer:
|
||||
|
||||
def __init__(self, logfile=sys.stderr):
|
||||
self.commited_in_buffer = []
|
||||
self.buffer = []
|
||||
self.new = []
|
||||
|
||||
self.last_commited_time = 0
|
||||
self.last_commited_word = None
|
||||
|
||||
self.logfile = logfile
|
||||
|
||||
def insert(self, new, offset):
|
||||
# compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
|
||||
# the new tail is added to self.new
|
||||
|
||||
new = [(a+offset,b+offset,t) for a,b,t in new]
|
||||
self.new = [(a,b,t) for a,b,t in new if a > self.last_commited_time-0.1]
|
||||
|
||||
if len(self.new) >= 1:
|
||||
a,b,t = self.new[0]
|
||||
if abs(a - self.last_commited_time) < 1:
|
||||
if self.commited_in_buffer:
|
||||
# it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
|
||||
cn = len(self.commited_in_buffer)
|
||||
nn = len(self.new)
|
||||
for i in range(1,min(min(cn,nn),5)+1): # 5 is the maximum
|
||||
c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
|
||||
tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
|
||||
if c == tail:
|
||||
words = []
|
||||
for j in range(i):
|
||||
words.append(repr(self.new.pop(0)))
|
||||
words_msg = " ".join(words)
|
||||
logger.debug(f"removing last {i} words: {words_msg}")
|
||||
break
|
||||
|
||||
def flush(self):
|
||||
# returns commited chunk = the longest common prefix of 2 last inserts.
|
||||
|
||||
commit = []
|
||||
while self.new:
|
||||
na, nb, nt = self.new[0]
|
||||
|
||||
if len(self.buffer) == 0:
|
||||
break
|
||||
|
||||
if nt == self.buffer[0][2]:
|
||||
commit.append((na,nb,nt))
|
||||
self.last_commited_word = nt
|
||||
self.last_commited_time = nb
|
||||
self.buffer.pop(0)
|
||||
self.new.pop(0)
|
||||
else:
|
||||
break
|
||||
self.buffer = self.new
|
||||
self.new = []
|
||||
self.commited_in_buffer.extend(commit)
|
||||
return commit
|
||||
|
||||
def pop_commited(self, time):
|
||||
while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
|
||||
self.commited_in_buffer.pop(0)
|
||||
|
||||
def complete(self):
|
||||
return self.buffer
|
||||
|
||||
class OnlineASRProcessor:
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
|
||||
"""asr: WhisperASR object
|
||||
tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
|
||||
("segment", 15)
|
||||
buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
|
||||
logfile: where to store the log.
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenizer = tokenizer
|
||||
self.logfile = logfile
|
||||
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
|
||||
|
||||
def init(self):
|
||||
"""run this when starting or restarting processing"""
|
||||
self.audio_buffer = np.array([],dtype=np.float32)
|
||||
self.buffer_time_offset = 0
|
||||
|
||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
|
||||
self.commited = []
|
||||
|
||||
def insert_audio_chunk(self, audio):
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
def prompt(self):
|
||||
"""Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
|
||||
"context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
|
||||
"""
|
||||
k = max(0,len(self.commited)-1)
|
||||
while k > 0 and self.commited[k-1][1] > self.buffer_time_offset:
|
||||
k -= 1
|
||||
|
||||
p = self.commited[:k]
|
||||
p = [t for _,_,t in p]
|
||||
prompt = []
|
||||
l = 0
|
||||
while p and l < 200: # 200 characters prompt size
|
||||
x = p.pop(-1)
|
||||
l += len(x)+1
|
||||
prompt.append(x)
|
||||
non_prompt = self.commited[k:]
|
||||
return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _,_,t in non_prompt)
|
||||
|
||||
def process_iter(self):
|
||||
"""Runs on the current audio buffer.
|
||||
Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
|
||||
        The non-empty text is the confirmed (committed) partial transcript.
|
||||
"""
|
||||
|
||||
prompt, non_prompt = self.prompt()
|
||||
logger.debug(f"PROMPT: {prompt}")
|
||||
logger.debug(f"CONTEXT: {non_prompt}")
|
||||
logger.debug(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}")
|
||||
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
|
||||
|
||||
# transform to [(beg,end,"word1"), ...]
|
||||
tsw = self.asr.ts_words(res)
|
||||
|
||||
self.transcript_buffer.insert(tsw, self.buffer_time_offset)
|
||||
o = self.transcript_buffer.flush()
|
||||
self.commited.extend(o)
|
||||
completed = self.to_flush(o)
|
||||
logger.debug(f">>>>COMPLETE NOW: {completed}")
|
||||
the_rest = self.to_flush(self.transcript_buffer.complete())
|
||||
logger.debug(f"INCOMPLETE: {the_rest}")
|
||||
|
||||
# there is a newly confirmed text
|
||||
|
||||
if o and self.buffer_trimming_way == "sentence": # trim the completed sentences
|
||||
if len(self.audio_buffer)/self.SAMPLING_RATE > self.buffer_trimming_sec: # longer than this
|
||||
self.chunk_completed_sentence()
|
||||
|
||||
|
||||
if self.buffer_trimming_way == "segment":
|
||||
s = self.buffer_trimming_sec # trim the completed segments longer than s,
|
||||
else:
|
||||
s = 30 # if the audio buffer is longer than 30s, trim it
|
||||
|
||||
if len(self.audio_buffer)/self.SAMPLING_RATE > s:
|
||||
self.chunk_completed_segment(res)
|
||||
|
||||
# alternative: on any word
|
||||
#l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
|
||||
# let's find commited word that is less
|
||||
#k = len(self.commited)-1
|
||||
#while k>0 and self.commited[k][1] > l:
|
||||
# k -= 1
|
||||
#t = self.commited[k][1]
|
||||
logger.debug("chunking segment")
|
||||
#self.chunk_at(t)
|
||||
|
||||
logger.debug(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}")
|
||||
return self.to_flush(o)
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
if self.commited == []: return
|
||||
logger.debug(self.commited)
|
||||
sents = self.words_to_sentences(self.commited)
|
||||
for s in sents:
|
||||
logger.debug(f"\t\tSENT: {s}")
|
||||
if len(sents) < 2:
|
||||
return
|
||||
while len(sents) > 2:
|
||||
sents.pop(0)
|
||||
# we will continue with audio processing at this timestamp
|
||||
chunk_at = sents[-2][1]
|
||||
|
||||
logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
|
||||
self.chunk_at(chunk_at)
|
||||
|
||||
def chunk_completed_segment(self, res):
|
||||
if self.commited == []: return
|
||||
|
||||
ends = self.asr.segments_end_ts(res)
|
||||
|
||||
t = self.commited[-1][1]
|
||||
|
||||
if len(ends) > 1:
|
||||
|
||||
e = ends[-2]+self.buffer_time_offset
|
||||
while len(ends) > 2 and e > t:
|
||||
ends.pop(-1)
|
||||
e = ends[-2]+self.buffer_time_offset
|
||||
if e <= t:
|
||||
logger.debug(f"--- segment chunked at {e:2.2f}")
|
||||
self.chunk_at(e)
|
||||
else:
|
||||
logger.debug(f"--- last segment not within commited area")
|
||||
else:
|
||||
logger.debug(f"--- not enough segments to chunk")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def chunk_at(self, time):
|
||||
"""trims the hypothesis and audio buffer at "time"
|
||||
"""
|
||||
self.transcript_buffer.pop_commited(time)
|
||||
cut_seconds = time - self.buffer_time_offset
|
||||
self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
|
||||
self.buffer_time_offset = time
|
||||
|
||||
def words_to_sentences(self, words):
|
||||
"""Uses self.tokenizer for sentence segmentation of words.
|
||||
Returns: [(beg,end,"sentence 1"),...]
|
||||
"""
|
||||
|
||||
cwords = [w for w in words]
|
||||
t = " ".join(o[2] for o in cwords)
|
||||
s = self.tokenizer.split(t)
|
||||
out = []
|
||||
while s:
|
||||
beg = None
|
||||
end = None
|
||||
sent = s.pop(0).strip()
|
||||
fsent = sent
|
||||
while cwords:
|
||||
b,e,w = cwords.pop(0)
|
||||
w = w.strip()
|
||||
if beg is None and sent.startswith(w):
|
||||
beg = b
|
||||
elif end is None and sent == w:
|
||||
end = e
|
||||
out.append((beg,end,fsent))
|
||||
break
|
||||
sent = sent[len(w):].strip()
|
||||
return out
|
||||
|
||||
def finish(self):
|
||||
"""Flush the incomplete text when the whole processing ends.
|
||||
Returns: the same format as self.process_iter()
|
||||
"""
|
||||
o = self.transcript_buffer.complete()
|
||||
f = self.to_flush(o)
|
||||
logger.debug("last, noncommited: {f}")
|
||||
return f
|
||||
|
||||
|
||||
def to_flush(self, sents, sep=None, offset=0, ):
|
||||
# concatenates the timestamped words or sentences into one sequence that is flushed in one line
|
||||
# sents: [(beg1, end1, "sentence1"), ...] or [] if empty
|
||||
# return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
|
||||
if sep is None:
|
||||
sep = self.asr.sep
|
||||
t = sep.join(s[2] for s in sents)
|
||||
if len(sents) == 0:
|
||||
b = None
|
||||
e = None
|
||||
else:
|
||||
b = offset + sents[0][0]
|
||||
e = offset + sents[-1][1]
|
||||
return (b,e,t)
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
|
||||
from mosestokenizer import MosesTokenizer
|
||||
return MosesTokenizer(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
|
||||
logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def add_shared_args(parser):
|
||||
"""shared args for simulation (this entry point) and server
|
||||
parser: argparse.ArgumentParser object
|
||||
"""
|
||||
parser.add_argument('--min-chunk-size', type=float, default=1.0, help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
|
||||
parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
|
||||
parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
|
||||
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
|
||||
parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
|
||||
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
|
||||
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
|
||||
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
||||
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
|
||||
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
|
||||
parser.add_argument("-l", "--log-level", dest="log_level", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the log level", default='DEBUG')
|
||||
|
||||
def asr_factory(args, logfile=sys.stderr):
|
||||
"""
|
||||
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
||||
"""
|
||||
backend = args.backend
|
||||
if backend == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=args.lan)
|
||||
else:
|
||||
if backend == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
else:
|
||||
asr_cls = WhisperTimestampedASR
|
||||
|
||||
# Only for FasterWhisperASR and WhisperTimestampedASR
|
||||
size = args.model
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {size} model for {args.lan}...")
|
||||
asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
# Apply common configurations
|
||||
if getattr(args, 'vad', False): # Checks if VAD argument is present and True
|
||||
logger.info("Setting VAD filter")
|
||||
asr.use_vad()
|
||||
|
||||
language = args.lan
|
||||
if args.task == "translate":
|
||||
asr.set_translate_task()
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = language # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if args.buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
# Create the OnlineASRProcessor
|
||||
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
||||
|
||||
return asr, online
|
||||
|
||||
def set_logging(args,logger,other="_server"):
|
||||
logging.basicConfig(#format='%(name)s
|
||||
format='%(levelname)s\t%(message)s')
|
||||
logger.setLevel(args.log_level)
|
||||
logging.getLogger("whisper_online"+other).setLevel(args.log_level)
|
||||
# logging.getLogger("whisper_online_server").setLevel(args.log_level)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('audio_path', type=str, help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
|
||||
add_shared_args(parser)
|
||||
parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
|
||||
parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
|
||||
parser.add_argument('--comp_unaware', action="store_true", default=False, help='Computationally unaware simulation.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# reset to store stderr to different file stream, e.g. open(os.devnull,"w")
|
||||
logfile = sys.stderr
|
||||
|
||||
if args.offline and args.comp_unaware:
|
||||
logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
# if args.log_level:
|
||||
# logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
|
||||
# level=getattr(logging, args.log_level))
|
||||
|
||||
set_logging(args,logger)
|
||||
|
||||
audio_path = args.audio_path
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
||||
logger.info("Audio duration is: %2.2f seconds" % duration)
|
||||
|
||||
asr, online = asr_factory(args, logfile=logfile)
|
||||
min_chunk = args.min_chunk_size
|
||||
|
||||
# load the audio into the LRU cache before we start the timer
|
||||
a = load_audio_chunk(audio_path,0,1)
|
||||
|
||||
# warm up the ASR because the very first transcribe takes much more time than the other
|
||||
asr.transcribe(a)
|
||||
|
||||
beg = args.start_at
|
||||
start = time.time()-beg
|
||||
|
||||
def output_transcript(o, now=None):
|
||||
# output format in stdout is like:
|
||||
# 4186.3606 0 1720 Takhle to je
|
||||
# - the first three words are:
|
||||
# - emission time from beginning of processing, in milliseconds
|
||||
# - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
|
||||
# - the next words: segment transcript
|
||||
if now is None:
|
||||
now = time.time()-start
|
||||
if o[0] is not None:
|
||||
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
|
||||
print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
|
||||
else:
|
||||
# No text, so no output
|
||||
pass
|
||||
|
||||
if args.offline: ## offline mode processing (for testing/debugging)
|
||||
a = load_audio(audio_path)
|
||||
online.insert_audio_chunk(a)
|
||||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError as e:
|
||||
            logger.error(f"assertion error: {repr(e)}")
|
||||
else:
|
||||
output_transcript(o)
|
||||
now = None
|
||||
elif args.comp_unaware: # computational unaware mode
|
||||
end = beg + min_chunk
|
||||
while True:
|
||||
a = load_audio_chunk(audio_path,beg,end)
|
||||
online.insert_audio_chunk(a)
|
||||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError as e:
|
||||
logger.error(f"assertion error: {repr(e)}")
|
||||
pass
|
||||
else:
|
||||
output_transcript(o, now=end)
|
||||
|
||||
logger.debug(f"## last processed {end:.2f}s")
|
||||
|
||||
if end >= duration:
|
||||
break
|
||||
|
||||
beg = end
|
||||
|
||||
if end + min_chunk > duration:
|
||||
end = duration
|
||||
else:
|
||||
end += min_chunk
|
||||
now = duration
|
||||
|
||||
else: # online = simultaneous mode
|
||||
end = 0
|
||||
while True:
|
||||
now = time.time() - start
|
||||
if now < end+min_chunk:
|
||||
time.sleep(min_chunk+end-now)
|
||||
end = time.time() - start
|
||||
a = load_audio_chunk(audio_path,beg,end)
|
||||
beg = end
|
||||
online.insert_audio_chunk(a)
|
||||
|
||||
try:
|
||||
o = online.process_iter()
|
||||
except AssertionError as e:
|
||||
logger.error(f"assertion error: {e}")
|
||||
pass
|
||||
else:
|
||||
output_transcript(o)
|
||||
now = time.time() - start
|
||||
logger.debug(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now-end:.2f}")
|
||||
|
||||
if end >= duration:
|
||||
break
|
||||
now = None
|
||||
|
||||
o = online.finish()
|
||||
output_transcript(o, now=now)
|
||||
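The commit policy implemented by `HypothesisBuffer.flush()` above is a LocalAgreement-style rule: a word is only emitted once it appears at the same position in two consecutive hypotheses. A minimal, self-contained sketch of that idea on plain word lists (the function name and example are illustrative, not taken from the removed file):

```python
def agreed_prefix(prev_hypothesis: list[str], new_hypothesis: list[str]) -> list[str]:
    """Longest common prefix of two consecutive hypotheses: the 'committed' words."""
    committed = []
    for old_word, new_word in zip(prev_hypothesis, new_hypothesis):
        if old_word != new_word:
            break
        committed.append(new_word)
    return committed

# Only the stable prefix is committed; the unstable tail stays buffered.
print(agreed_prefix(["the", "cat", "sat"], ["the", "cat", "sits", "on"]))  # ['the', 'cat']
```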
@@ -1,175 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from whisper_online import *
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# server options
|
||||
parser.add_argument("--host", type=str, default='localhost')
|
||||
parser.add_argument("--port", type=int, default=43007)
|
||||
parser.add_argument("--warmup-file", type=str, dest="warmup_file",
|
||||
help="The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast. It can be e.g. https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav .")
|
||||
|
||||
|
||||
# options from whisper_online
|
||||
add_shared_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
set_logging(args,logger,other="")
|
||||
|
||||
# setting whisper object by args
|
||||
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
size = args.model
|
||||
language = args.lan
|
||||
asr, online = asr_factory(args)
|
||||
min_chunk = args.min_chunk_size
|
||||
|
||||
# warm up the ASR because the very first transcribe takes more time than the others.
|
||||
# Test results in https://github.com/ufal/whisper_streaming/pull/81
|
||||
msg = "Whisper is not warmed up. The first chunk processing may take longer."
|
||||
if args.warmup_file:
|
||||
if os.path.isfile(args.warmup_file):
|
||||
a = load_audio_chunk(args.warmup_file,0,1)
|
||||
asr.transcribe(a)
|
||||
logger.info("Whisper is warmed up.")
|
||||
else:
|
||||
logger.critical("The warm up file is not available. "+msg)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.warning(msg)
|
||||
|
||||
|
||||
######### Server objects
|
||||
|
||||
import line_packet
|
||||
import socket
|
||||
|
||||
class Connection:
|
||||
'''it wraps conn object'''
|
||||
PACKET_SIZE = 65536
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self.last_line = ""
|
||||
|
||||
self.conn.setblocking(True)
|
||||
|
||||
def send(self, line):
|
||||
'''it doesn't send the same line twice, because it was problematic in online-text-flow-events'''
|
||||
if line == self.last_line:
|
||||
return
|
||||
line_packet.send_one_line(self.conn, line)
|
||||
self.last_line = line
|
||||
|
||||
def receive_lines(self):
|
||||
in_line = line_packet.receive_lines(self.conn)
|
||||
return in_line
|
||||
|
||||
def non_blocking_receive_audio(self):
|
||||
r = self.conn.recv(self.PACKET_SIZE)
|
||||
return r
|
||||
|
||||
|
||||
import io
|
||||
import soundfile
|
||||
|
||||
# wraps socket and ASR object, and serves one client connection.
|
||||
# next client should be served by a new instance of this object
|
||||
class ServerProcessor:
|
||||
|
||||
def __init__(self, c, online_asr_proc, min_chunk):
|
||||
self.connection = c
|
||||
self.online_asr_proc = online_asr_proc
|
||||
self.min_chunk = min_chunk
|
||||
|
||||
self.last_end = None
|
||||
|
||||
def receive_audio_chunk(self):
|
||||
# receive all audio that is available by this time
|
||||
# blocks operation if less than self.min_chunk seconds is available
|
||||
# unblocks if connection is closed or a chunk is available
|
||||
out = []
|
||||
while sum(len(x) for x in out) < self.min_chunk*SAMPLING_RATE:
|
||||
raw_bytes = self.connection.non_blocking_receive_audio()
|
||||
if not raw_bytes:
|
||||
break
|
||||
sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
|
||||
audio, _ = librosa.load(sf,sr=SAMPLING_RATE,dtype=np.float32)
|
||||
out.append(audio)
|
||||
if not out:
|
||||
return None
|
||||
return np.concatenate(out)
|
||||
|
||||
def format_output_transcript(self,o):
|
||||
# output format in stdout is like:
|
||||
# 0 1720 Takhle to je
|
||||
# - the first two words are:
|
||||
# - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
|
||||
# - the next words: segment transcript
|
||||
|
||||
# This function differs from whisper_online.output_transcript in the following:
|
||||
# succeeding [beg,end] intervals are not overlapping because ELITR protocol (implemented in online-text-flow events) requires it.
|
||||
        # Therefore, beg is the max of the previous end and the current beg output by Whisper.
|
||||
# Usually it differs negligibly, by appx 20 ms.
|
||||
|
||||
if o[0] is not None:
|
||||
beg, end = o[0]*1000,o[1]*1000
|
||||
if self.last_end is not None:
|
||||
beg = max(beg, self.last_end)
|
||||
|
||||
self.last_end = end
|
||||
print("%1.0f %1.0f %s" % (beg,end,o[2]),flush=True,file=sys.stderr)
|
||||
return "%1.0f %1.0f %s" % (beg,end,o[2])
|
||||
else:
|
||||
logger.debug("No text in this segment")
|
||||
return None
|
||||
|
||||
def send_result(self, o):
|
||||
msg = self.format_output_transcript(o)
|
||||
if msg is not None:
|
||||
self.connection.send(msg)
|
||||
|
||||
def process(self):
|
||||
# handle one client connection
|
||||
self.online_asr_proc.init()
|
||||
while True:
|
||||
a = self.receive_audio_chunk()
|
||||
if a is None:
|
||||
break
|
||||
self.online_asr_proc.insert_audio_chunk(a)
|
||||
            o = self.online_asr_proc.process_iter()
|
||||
try:
|
||||
self.send_result(o)
|
||||
except BrokenPipeError:
|
||||
logger.info("broken pipe -- connection closed?")
|
||||
break
|
||||
|
||||
# o = online.finish() # this should be working
|
||||
# self.send_result(o)
|
||||
|
||||
|
||||
|
||||
# server loop
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    # the context manager already yields the listening socket; re-creating it here would leak it
|
||||
s.bind((args.host, args.port))
|
||||
s.listen(1)
|
||||
    logger.info('Listening on ' + str((args.host, args.port)))
|
||||
while True:
|
||||
conn, addr = s.accept()
|
||||
logger.info('Connected to client on {}'.format(addr))
|
||||
connection = Connection(conn)
|
||||
proc = ServerProcessor(connection, online, min_chunk)
|
||||
proc.process()
|
||||
conn.close()
|
||||
logger.info('Connection to client closed')
|
||||
logger.info('Connection closed, terminating.')
|
||||
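For context, the removed server above expects raw 16 kHz, 16-bit mono PCM on the socket and answers with padded 65536-byte text packets in the `line_packet` format ("beg_ms end_ms text"). A minimal client sketch under those assumptions; the host, port, file name and chunking are placeholders, not part of the removed code:

```python
import socket

import soundfile as sf

PACKET_SIZE = 65536  # must match the server's line_packet.PACKET_SIZE


def stream_file(path="audio.wav", host="localhost", port=43007, chunk_seconds=1.0):
    audio, sr = sf.read(path, dtype="int16")        # expects a 16 kHz mono file
    assert sr == 16000, "the server expects 16 kHz PCM_16 audio"
    step = int(sr * chunk_seconds)
    with socket.create_connection((host, port)) as conn:
        conn.settimeout(0.1)
        for i in range(0, len(audio), step):
            conn.sendall(audio[i:i + step].tobytes())   # raw little-endian PCM_16
            try:
                packet = conn.recv(PACKET_SIZE)          # "<beg_ms> <end_ms> <text>" padded with \0
                if packet:
                    print(packet.decode("utf-8", errors="replace").strip("\0\n"))
            except socket.timeout:
                pass
```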
168
whisperlivekit/TokensAlignment.py
Normal file
@@ -0,0 +1,168 @@
|
||||
class TokensAlignment:
|
||||
|
||||
def __init__(self, state_light, silence=None, args=None):
|
||||
self.state_light = state_light
|
||||
self.silence = silence
|
||||
self.args = args
|
||||
|
||||
self._tokens_index = 0
|
||||
self._diarization_index = 0
|
||||
self._translation_index = 0
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
|
||||
def compute_punctuations_segments(self):
|
||||
punctuations_breaks = []
|
||||
new_tokens = self.state.tokens[self.state.last_validated_token:]
|
||||
for i in range(len(new_tokens)):
|
||||
token = new_tokens[i]
|
||||
if token.is_punctuation():
|
||||
punctuations_breaks.append({
|
||||
'token_index': i,
|
||||
'token': token,
|
||||
'start': token.start,
|
||||
'end': token.end,
|
||||
})
|
||||
punctuations_segments = []
|
||||
for i, break_info in enumerate(punctuations_breaks):
|
||||
start = punctuations_breaks[i - 1]['end'] if i > 0 else 0.0
|
||||
end = break_info['end']
|
||||
punctuations_segments.append({
|
||||
'start': start,
|
||||
'end': end,
|
||||
'token_index': break_info['token_index'],
|
||||
'token': break_info['token']
|
||||
})
|
||||
return punctuations_segments
|
||||
|
||||
def concatenate_diar_segments(self):
|
||||
diarization_segments = self.state.diarization_segments
|
||||
|
||||
if __name__ == "__main__":
|
||||
from whisperlivekit.timed_objects import State, ASRToken, SpeakerSegment, Transcript, Silence
|
||||
|
||||
# Reconstruct the state from the backup data
|
||||
tokens = [
|
||||
ASRToken(start=1.38, end=1.48, text=' The'),
|
||||
ASRToken(start=1.42, end=1.52, text=' description'),
|
||||
ASRToken(start=1.82, end=1.92, text=' technology'),
|
||||
ASRToken(start=2.54, end=2.64, text=' has'),
|
||||
ASRToken(start=2.7, end=2.8, text=' improved'),
|
||||
ASRToken(start=3.24, end=3.34, text=' so'),
|
||||
ASRToken(start=3.66, end=3.76, text=' much'),
|
||||
ASRToken(start=4.02, end=4.12, text=' in'),
|
||||
ASRToken(start=4.08, end=4.18, text=' the'),
|
||||
ASRToken(start=4.26, end=4.36, text=' past'),
|
||||
ASRToken(start=4.48, end=4.58, text=' few'),
|
||||
ASRToken(start=4.76, end=4.86, text=' years'),
|
||||
ASRToken(start=5.76, end=5.86, text='.'),
|
||||
ASRToken(start=5.72, end=5.82, text=' Have'),
|
||||
ASRToken(start=5.92, end=6.02, text=' you'),
|
||||
ASRToken(start=6.08, end=6.18, text=' noticed'),
|
||||
ASRToken(start=6.52, end=6.62, text=' how'),
|
||||
ASRToken(start=6.8, end=6.9, text=' accurate'),
|
||||
ASRToken(start=7.46, end=7.56, text=' real'),
|
||||
ASRToken(start=7.72, end=7.82, text='-time'),
|
||||
ASRToken(start=8.06, end=8.16, text=' speech'),
|
||||
ASRToken(start=8.48, end=8.58, text=' to'),
|
||||
ASRToken(start=8.68, end=8.78, text=' text'),
|
||||
ASRToken(start=9.0, end=9.1, text=' is'),
|
||||
ASRToken(start=9.24, end=9.34, text=' now'),
|
||||
ASRToken(start=9.82, end=9.92, text='?'),
|
||||
ASRToken(start=9.86, end=9.96, text=' Absolutely'),
|
||||
ASRToken(start=11.26, end=11.36, text='.'),
|
||||
ASRToken(start=11.36, end=11.46, text=' I'),
|
||||
ASRToken(start=11.58, end=11.68, text=' use'),
|
||||
ASRToken(start=11.78, end=11.88, text=' it'),
|
||||
ASRToken(start=11.94, end=12.04, text=' all'),
|
||||
ASRToken(start=12.08, end=12.18, text=' the'),
|
||||
ASRToken(start=12.32, end=12.42, text=' time'),
|
||||
ASRToken(start=12.58, end=12.68, text=' for'),
|
||||
ASRToken(start=12.78, end=12.88, text=' taking'),
|
||||
ASRToken(start=13.14, end=13.24, text=' notes'),
|
||||
ASRToken(start=13.4, end=13.5, text=' during'),
|
||||
ASRToken(start=13.78, end=13.88, text=' meetings'),
|
||||
ASRToken(start=14.6, end=14.7, text='.'),
|
||||
ASRToken(start=14.82, end=14.92, text=' It'),
|
||||
ASRToken(start=14.92, end=15.02, text="'s"),
|
||||
ASRToken(start=15.04, end=15.14, text=' amazing'),
|
||||
ASRToken(start=15.5, end=15.6, text=' how'),
|
||||
ASRToken(start=15.66, end=15.76, text=' it'),
|
||||
ASRToken(start=15.8, end=15.9, text=' can'),
|
||||
ASRToken(start=15.96, end=16.06, text=' recognize'),
|
||||
ASRToken(start=16.58, end=16.68, text=' different'),
|
||||
ASRToken(start=16.94, end=17.04, text=' speakers'),
|
||||
ASRToken(start=17.82, end=17.92, text=' and'),
|
||||
ASRToken(start=18.0, end=18.1, text=' even'),
|
||||
ASRToken(start=18.42, end=18.52, text=' add'),
|
||||
ASRToken(start=18.74, end=18.84, text=' punct'),
|
||||
ASRToken(start=19.02, end=19.12, text='uation'),
|
||||
ASRToken(start=19.68, end=19.78, text='.'),
|
||||
ASRToken(start=20.04, end=20.14, text=' Yeah'),
|
||||
ASRToken(start=20.5, end=20.6, text=','),
|
||||
ASRToken(start=20.6, end=20.7, text=' but'),
|
||||
ASRToken(start=20.76, end=20.86, text=' sometimes'),
|
||||
ASRToken(start=21.42, end=21.52, text=' noise'),
|
||||
ASRToken(start=21.82, end=21.92, text=' can'),
|
||||
ASRToken(start=22.08, end=22.18, text=' still'),
|
||||
ASRToken(start=22.38, end=22.48, text=' cause'),
|
||||
ASRToken(start=22.72, end=22.82, text=' mistakes'),
|
||||
ASRToken(start=23.74, end=23.84, text='.'),
|
||||
ASRToken(start=23.96, end=24.06, text=' Does'),
|
||||
ASRToken(start=24.16, end=24.26, text=' this'),
|
||||
ASRToken(start=24.4, end=24.5, text=' system'),
|
||||
ASRToken(start=24.76, end=24.86, text=' handle'),
|
||||
ASRToken(start=25.12, end=25.22, text=' that'),
|
||||
ASRToken(start=25.38, end=25.48, text=' well'),
|
||||
ASRToken(start=25.68, end=25.78, text='?'),
|
||||
ASRToken(start=26.4, end=26.5, text=' It'),
|
||||
ASRToken(start=26.5, end=26.6, text=' does'),
|
||||
ASRToken(start=26.7, end=26.8, text=' a'),
|
||||
ASRToken(start=27.08, end=27.18, text=' pretty'),
|
||||
ASRToken(start=27.12, end=27.22, text=' good'),
|
||||
ASRToken(start=27.34, end=27.44, text=' job'),
|
||||
ASRToken(start=27.64, end=27.74, text=' filtering'),
|
||||
ASRToken(start=28.1, end=28.2, text=' noise'),
|
||||
ASRToken(start=28.64, end=28.74, text=','),
|
||||
ASRToken(start=28.78, end=28.88, text=' especially'),
|
||||
ASRToken(start=29.3, end=29.4, text=' with'),
|
||||
ASRToken(start=29.51, end=29.61, text=' models'),
|
||||
ASRToken(start=29.99, end=30.09, text=' that'),
|
||||
ASRToken(start=30.21, end=30.31, text=' use'),
|
||||
ASRToken(start=30.51, end=30.61, text=' voice'),
|
||||
ASRToken(start=30.83, end=30.93, text=' activity'),
|
||||
]
|
||||
|
||||
diarization_segments = [
|
||||
SpeakerSegment(start=1.3255040645599365, end=4.3255040645599365, speaker=0),
|
||||
SpeakerSegment(start=4.806154012680054, end=9.806154012680054, speaker=0),
|
||||
SpeakerSegment(start=9.806154012680054, end=10.806154012680054, speaker=1),
|
||||
SpeakerSegment(start=11.168735027313232, end=14.168735027313232, speaker=1),
|
||||
SpeakerSegment(start=14.41029405593872, end=17.41029405593872, speaker=1),
|
||||
SpeakerSegment(start=17.52983808517456, end=19.52983808517456, speaker=1),
|
||||
SpeakerSegment(start=19.64953374862671, end=20.066200415293377, speaker=1),
|
||||
SpeakerSegment(start=20.066200415293377, end=22.64953374862671, speaker=2),
|
||||
SpeakerSegment(start=23.012792587280273, end=25.012792587280273, speaker=2),
|
||||
SpeakerSegment(start=25.495875597000122, end=26.41254226366679, speaker=2),
|
||||
SpeakerSegment(start=26.41254226366679, end=30.495875597000122, speaker=0),
|
||||
]
|
||||
|
||||
state = State(
|
||||
tokens=tokens,
|
||||
last_validated_token=72,
|
||||
last_speaker=-1,
|
||||
last_punctuation_index=71,
|
||||
translation_validated_segments=[],
|
||||
buffer_translation=Transcript(start=0, end=0, speaker=-1),
|
||||
buffer_transcription=Transcript(start=None, end=None, speaker=-1),
|
||||
diarization_segments=diarization_segments,
|
||||
end_buffer=31.21587559700018,
|
||||
end_attributed_speaker=30.495875597000122,
|
||||
remaining_time_transcription=0.4,
|
||||
remaining_time_diarization=0.7,
|
||||
beg_loop=1763627603.968919
|
||||
)
|
||||
|
||||
alignment = TokensAlignment(state)
|
||||
13
whisperlivekit/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .audio_processor import AudioProcessor
|
||||
from .core import TranscriptionEngine
|
||||
from .parse_args import parse_args
|
||||
from .web.web_interface import get_web_interface_html, get_inline_ui_html
|
||||
|
||||
__all__ = [
|
||||
"TranscriptionEngine",
|
||||
"AudioProcessor",
|
||||
"parse_args",
|
||||
"get_web_interface_html",
|
||||
"get_inline_ui_html",
|
||||
"download_simulstreaming_backend",
|
||||
]
|
||||
661
whisperlivekit/audio_processor.py
Normal file
@@ -0,0 +1,661 @@
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from time import time, sleep
|
||||
import math
|
||||
import logging
|
||||
import traceback
|
||||
from whisperlivekit.timed_objects import ASRToken, Silence, Line, FrontData, State, StateLight, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.core import TranscriptionEngine, online_factory, online_diarization_factory, online_translation_factory
|
||||
from whisperlivekit.silero_vad_iterator import FixedVADIterator
|
||||
from whisperlivekit.results_formater import format_output
|
||||
from whisperlivekit.ffmpeg_manager import FFmpegManager, FFmpegState
|
||||
from whisperlivekit.TokensAlignment import TokensAlignment
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
SENTINEL = object() # unique sentinel object for end of stream marker
|
||||
|
||||
def cut_at(cumulative_pcm, cut_sec):
|
||||
cumulative_len = 0
|
||||
cut_sample = int(cut_sec * 16000)
|
||||
|
||||
for ind, pcm_array in enumerate(cumulative_pcm):
|
||||
if (cumulative_len + len(pcm_array)) >= cut_sample:
|
||||
cut_chunk = cut_sample - cumulative_len
|
||||
before = np.concatenate(cumulative_pcm[:ind] + [cumulative_pcm[ind][:cut_chunk]])
|
||||
after = [cumulative_pcm[ind][cut_chunk:]] + cumulative_pcm[ind+1:]
|
||||
return before, after
|
||||
cumulative_len += len(pcm_array)
|
||||
return np.concatenate(cumulative_pcm), []
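A minimal, illustrative sketch of how cut_at is meant to be used (the values below are assumptions, not part of this diff): splitting a list of buffered PCM chunks at a 1.5 s boundary returns the first 1.5 s as a single array and the remainder as a list of arrays.
# Illustrative only: split two 1 s chunks (16 kHz) at the 1.5 s mark.
_chunks = [np.zeros(16000, dtype=np.float32), np.ones(16000, dtype=np.float32)]
_before, _after = cut_at(_chunks, 1.5)
assert len(_before) == 24000                          # 1.5 s of samples in one array
assert len(_after) == 1 and len(_after[0]) == 8000    # remaining 0.5 s still chunked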
|
||||
|
||||
async def get_all_from_queue(queue):
|
||||
items = []
|
||||
|
||||
first_item = await queue.get()
|
||||
queue.task_done()
|
||||
if first_item is SENTINEL:
|
||||
return first_item
|
||||
if isinstance(first_item, Silence):
|
||||
return first_item
|
||||
items.append(first_item)
|
||||
|
||||
while True:
|
||||
if not queue._queue:
|
||||
break
|
||||
next_item = queue._queue[0]
|
||||
if next_item is SENTINEL:
|
||||
break
|
||||
if isinstance(next_item, Silence):
|
||||
break
|
||||
items.append(await queue.get())
|
||||
queue.task_done()
|
||||
if isinstance(items[0], np.ndarray):
|
||||
return np.concatenate(items)
|
||||
else: #translation
|
||||
return items
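A hedged usage sketch (the queue contents below are made up): get_all_from_queue awaits the first item, then drains whatever else is already queued, so downstream processors work on the largest available chunk instead of one queue entry per iteration.
async def _drain_demo():
    q = asyncio.Queue()
    for _ in range(3):
        await q.put(np.zeros(1600, dtype=np.float32))
    merged = await get_all_from_queue(q)
    assert isinstance(merged, np.ndarray) and len(merged) == 4800  # three chunks concatenated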
|
||||
|
||||
class AudioProcessor:
|
||||
"""
|
||||
Processes audio streams for transcription and diarization.
|
||||
Handles audio processing, state management, and result formatting.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the audio processor with configuration, models, and state."""
|
||||
|
||||
if 'transcription_engine' in kwargs and isinstance(kwargs['transcription_engine'], TranscriptionEngine):
|
||||
models = kwargs['transcription_engine']
|
||||
else:
|
||||
models = TranscriptionEngine(**kwargs)
|
||||
|
||||
# Audio processing settings
|
||||
self.args = models.args
|
||||
self.sample_rate = 16000
|
||||
self.channels = 1
|
||||
self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
|
||||
self.bytes_per_sample = 2
|
||||
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
||||
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of 16 kHz, 16-bit mono audio (32,000 bytes/s)
|
||||
self.is_pcm_input = self.args.pcm_input
|
||||
|
||||
# State management
|
||||
self.is_stopping = False
|
||||
self.silence = True
|
||||
self.silence_duration = 0.0
|
||||
self.start_silence = None
|
||||
self.last_silence_dispatch_time = None
|
||||
self.state = State()
|
||||
self.state_light = StateLight()
|
||||
self.lock = asyncio.Lock()
|
||||
self.sep = " " # Default separator
|
||||
self.last_response_content = FrontData()
|
||||
self.last_detected_speaker = None
|
||||
self.speaker_languages = {}
|
||||
|
||||
self.tokens_alignment = TokensAlignment(self.state_light, self.args, self.sep)
|
||||
self.beg_loop = None
|
||||
|
||||
# Models and processing
|
||||
self.asr = models.asr
|
||||
self.vac_model = models.vac_model
|
||||
if self.args.vac:
|
||||
self.vac = FixedVADIterator(models.vac_model)
|
||||
else:
|
||||
self.vac = None
|
||||
|
||||
self.ffmpeg_manager = None
|
||||
self.ffmpeg_reader_task = None
|
||||
self._ffmpeg_error = None
|
||||
|
||||
if not self.is_pcm_input:
|
||||
self.ffmpeg_manager = FFmpegManager(
|
||||
sample_rate=self.sample_rate,
|
||||
channels=self.channels
|
||||
)
|
||||
async def handle_ffmpeg_error(error_type: str):
|
||||
logger.error(f"FFmpeg error: {error_type}")
|
||||
self._ffmpeg_error = error_type
|
||||
self.ffmpeg_manager.on_error_callback = handle_ffmpeg_error
|
||||
|
||||
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
||||
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
||||
self.translation_queue = asyncio.Queue() if self.args.target_language else None
|
||||
self.pcm_buffer = bytearray()
|
||||
self.total_pcm_samples = 0
|
||||
self.end_buffer = 0.0
|
||||
self.transcription_task = None
|
||||
self.diarization_task = None
|
||||
self.translation_task = None
|
||||
self.watchdog_task = None
|
||||
self.all_tasks_for_cleanup = []
|
||||
|
||||
self.transcription = None
|
||||
self.translation = None
|
||||
self.diarization = None
|
||||
|
||||
if self.args.transcription:
|
||||
self.transcription = online_factory(self.args, models.asr)
|
||||
self.sep = self.transcription.asr.sep
|
||||
if self.args.diarization:
|
||||
self.diarization = online_diarization_factory(self.args, models.diarization_model)
|
||||
if models.translation_model:
|
||||
self.translation = online_translation_factory(self.args, models.translation_model)
|
||||
|
||||
async def _push_silence_event(self, silence_buffer: Silence):
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(silence_buffer)
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(silence_buffer)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(silence_buffer)
|
||||
|
||||
async def _begin_silence(self):
|
||||
if self.silence:
|
||||
return
|
||||
self.silence = True
|
||||
now = time()
|
||||
self.start_silence = now
|
||||
self.last_silence_dispatch_time = now
|
||||
await self._push_silence_event(Silence(is_starting=True))
|
||||
|
||||
async def _end_silence(self):
|
||||
if not self.silence:
|
||||
return
|
||||
now = time()
|
||||
duration = now - (self.last_silence_dispatch_time if self.last_silence_dispatch_time else self.beg_loop)
|
||||
await self._push_silence_event(Silence(duration=duration, has_ended=True))
|
||||
self.last_silence_dispatch_time = now
|
||||
self.silence = False
|
||||
self.start_silence = None
|
||||
self.last_silence_dispatch_time = None
|
||||
|
||||
async def _enqueue_active_audio(self, pcm_chunk: np.ndarray):
|
||||
if pcm_chunk is None or pcm_chunk.size == 0:
|
||||
return
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(pcm_chunk.copy())
|
||||
if self.args.diarization and self.diarization_queue:
|
||||
await self.diarization_queue.put(pcm_chunk.copy())
|
||||
self.silence_duration = 0.0
|
||||
|
||||
def _slice_before_silence(self, pcm_array, chunk_sample_start, silence_sample):
|
||||
if silence_sample is None:
|
||||
return None
|
||||
relative_index = int(silence_sample - chunk_sample_start)
|
||||
if relative_index <= 0:
|
||||
return None
|
||||
split_index = min(relative_index, len(pcm_array))
|
||||
if split_index <= 0:
|
||||
return None
|
||||
return pcm_array[:split_index]
|
||||
|
||||
def convert_pcm_to_float(self, pcm_buffer):
|
||||
"""Convert PCM buffer in s16le format to normalized NumPy array."""
|
||||
return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
async def get_current_state(self):
|
||||
"""Get current state."""
|
||||
async with self.lock:
|
||||
current_time = time()
|
||||
|
||||
remaining_transcription = 0
|
||||
if self.end_buffer > 0:
|
||||
remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 1))
|
||||
|
||||
remaining_diarization = 0
|
||||
if self.state.tokens:
|
||||
latest_end = max(self.end_buffer, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
remaining_diarization = max(0, round(latest_end - self.state.end_attributed_speaker, 1))
|
||||
|
||||
self.state.remaining_time_transcription = remaining_transcription
|
||||
self.state.remaining_time_diarization = remaining_diarization
|
||||
|
||||
return self.state
|
||||
|
||||
async def ffmpeg_stdout_reader(self):
|
||||
"""Read audio data from FFmpeg stdout and process it into the PCM pipeline."""
|
||||
beg = time()
|
||||
while True:
|
||||
try:
|
||||
if self.is_stopping:
|
||||
logger.info("Stopping ffmpeg_stdout_reader due to stopping flag.")
|
||||
break
|
||||
|
||||
state = await self.ffmpeg_manager.get_state() if self.ffmpeg_manager else FFmpegState.STOPPED
|
||||
if state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot read data")
|
||||
break
|
||||
elif state == FFmpegState.STOPPED:
|
||||
logger.info("FFmpeg is stopped")
|
||||
break
|
||||
elif state != FFmpegState.RUNNING:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
current_time = time()
|
||||
elapsed_time = max(0.0, current_time - beg)
|
||||
buffer_size = max(int(32000 * elapsed_time), 4096) # dynamic read
|
||||
beg = current_time
|
||||
|
||||
chunk = await self.ffmpeg_manager.read_data(buffer_size)
|
||||
if not chunk:
|
||||
# No data currently available
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
self.pcm_buffer.extend(chunk)
|
||||
await self.handle_pcm_data()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("ffmpeg_stdout_reader cancelled.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info("FFmpeg stdout processing finished. Signaling downstream processors if needed.")
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
if self.diarization:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
async def transcription_processor(self):
|
||||
"""Process audio chunks for transcription."""
|
||||
cumulative_pcm_duration_stream_time = 0.0
|
||||
|
||||
while True:
|
||||
try:
|
||||
# item = await self.transcription_queue.get()
|
||||
item = await get_all_from_queue(self.transcription_queue)
|
||||
if item is SENTINEL:
|
||||
logger.debug("Transcription processor received sentinel. Finishing.")
|
||||
break
|
||||
|
||||
asr_internal_buffer_duration_s = len(getattr(self.transcription, 'audio_buffer', [])) / self.transcription.SAMPLING_RATE
|
||||
transcription_lag_s = max(0.0, time() - self.beg_loop - self.end_buffer)
|
||||
asr_processing_logs = f"internal_buffer={asr_internal_buffer_duration_s:.2f}s | lag={transcription_lag_s:.2f}s |"
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
new_tokens = []
|
||||
current_audio_processed_upto = self.end_buffer
|
||||
|
||||
if isinstance(item, Silence):
|
||||
if item.is_starting:
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(
|
||||
self.transcription.start_silence
|
||||
)
|
||||
asr_processing_logs += f" + Silence starting"
|
||||
if item.has_ended:
|
||||
asr_processing_logs += f" + Silence of = {item.duration:.2f}s"
|
||||
cumulative_pcm_duration_stream_time += item.duration
|
||||
current_audio_processed_upto = cumulative_pcm_duration_stream_time
|
||||
self.transcription.end_silence(item.duration, self.state.tokens[-1].end if self.state.tokens else 0)
|
||||
if self.state.tokens:
|
||||
asr_processing_logs += f" | last_end = {self.state.tokens[-1].end} |"
|
||||
logger.info(asr_processing_logs)
|
||||
new_tokens = new_tokens or []
|
||||
current_audio_processed_upto = max(current_audio_processed_upto, stream_time_end_of_current_pcm)
|
||||
elif isinstance(item, ChangeSpeaker):
|
||||
self.transcription.new_speaker(item)
|
||||
continue
|
||||
elif isinstance(item, np.ndarray):
|
||||
pcm_array = item
|
||||
logger.info(asr_processing_logs)
|
||||
cumulative_pcm_duration_stream_time += len(pcm_array) / self.sample_rate
|
||||
stream_time_end_of_current_pcm = cumulative_pcm_duration_stream_time
|
||||
self.transcription.insert_audio_chunk(pcm_array, stream_time_end_of_current_pcm)
|
||||
new_tokens, current_audio_processed_upto = await asyncio.to_thread(self.transcription.process_iter)
|
||||
new_tokens = new_tokens or []
|
||||
|
||||
_buffer_transcript = self.transcription.get_buffer()
|
||||
buffer_text = _buffer_transcript.text
|
||||
|
||||
if new_tokens:
|
||||
validated_text = self.sep.join([t.text for t in new_tokens])
|
||||
if buffer_text.startswith(validated_text):
|
||||
_buffer_transcript.text = buffer_text[len(validated_text):].lstrip()
|
||||
|
||||
candidate_end_times = [self.end_buffer]
|
||||
|
||||
if new_tokens:
|
||||
candidate_end_times.append(new_tokens[-1].end)
|
||||
|
||||
if _buffer_transcript.end is not None:
|
||||
candidate_end_times.append(_buffer_transcript.end)
|
||||
|
||||
candidate_end_times.append(current_audio_processed_upto)
|
||||
|
||||
async with self.lock:
|
||||
self.state.tokens.extend(new_tokens)
|
||||
self.state.buffer_transcription = _buffer_transcript
|
||||
self.end_buffer = max(candidate_end_times)
|
||||
self.state_light.new_tokens = new_tokens
|
||||
self.state_light.new_tokens_index += 1  # assumed counter attribute, mirroring new_diarization_index; "new_tokens += 1" would try to add 1 to a list
|
||||
self.state_light.new_tokens_buffer = _buffer_transcript
|
||||
|
||||
if self.translation_queue:
|
||||
for token in new_tokens:
|
||||
await self.translation_queue.put(token)
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in transcription_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
if 'pcm_array' in locals() and pcm_array is not SENTINEL:  # check that pcm_array was assigned from the queue
|
||||
self.transcription_queue.task_done()
|
||||
|
||||
if self.is_stopping:
|
||||
logger.info("Transcription processor finishing due to stopping flag.")
|
||||
if self.diarization_queue:
|
||||
await self.diarization_queue.put(SENTINEL)
|
||||
if self.translation_queue:
|
||||
await self.translation_queue.put(SENTINEL)
|
||||
|
||||
logger.info("Transcription processor task finished.")
|
||||
|
||||
|
||||
async def diarization_processor(self):
|
||||
while True:
|
||||
try:
|
||||
item = await get_all_from_queue(self.diarization_queue)
|
||||
if item is SENTINEL:
|
||||
break
|
||||
elif type(item) is Silence:
|
||||
if item.has_ended:
|
||||
self.diarization.insert_silence(item.duration)
|
||||
continue
|
||||
|
||||
self.diarization.insert_audio_chunk(item)
|
||||
diarization_segments = await self.diarization.diarize()
|
||||
self.state_light.new_diarization = diarization_segments
|
||||
self.state_light.new_diarization_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in diarization_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.info("Diarization processor task finished.")
|
||||
|
||||
async def translation_processor(self):
|
||||
# the idea is to ignore diarization for the moment. We use only transcription tokens.
|
||||
# And the speaker is attributed given the segments used for the translation
|
||||
# in the future we want to have different languages for each speaker etc, so it will be more complex.
|
||||
while True:
|
||||
try:
|
||||
tokens_to_process = await get_all_from_queue(self.translation_queue)
|
||||
if tokens_to_process is SENTINEL:
|
||||
logger.debug("Translation processor received sentinel. Finishing.")
|
||||
self.translation_queue.task_done()
|
||||
break
|
||||
elif type(tokens_to_process) is Silence:
|
||||
if tokens_to_process.has_ended:
|
||||
self.translation.insert_silence(tokens_to_process.duration)
|
||||
continue
|
||||
if tokens_to_process:
|
||||
self.translation.insert_tokens(tokens_to_process)
|
||||
translation_validated_segments, buffer_translation = await asyncio.to_thread(self.translation.process)
|
||||
async with self.lock:
|
||||
self.state.translation_validated_segments = translation_validated_segments
|
||||
self.state.buffer_translation = buffer_translation
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in translation_processor: {e}")
|
||||
logger.warning(f"Traceback: {traceback.format_exc()}")
|
||||
logger.info("Translation processor task finished.")
|
||||
|
||||
async def results_formatter(self):
|
||||
"""Format processing results for output."""
|
||||
while True:
|
||||
try:
|
||||
if self._ffmpeg_error:
|
||||
yield FrontData(status="error", error=f"FFmpeg error: {self._ffmpeg_error}")
|
||||
self._ffmpeg_error = None
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
state = await self.get_current_state()
|
||||
self.tokens_alignment.compute_punctuations_segments()
|
||||
lines, undiarized_text = format_output(
|
||||
state,
|
||||
self.silence,
|
||||
args = self.args,
|
||||
sep=self.sep
|
||||
)
|
||||
if lines and lines[-1].speaker == -2:
|
||||
buffer_transcription = Transcript()
|
||||
else:
|
||||
buffer_transcription = state.buffer_transcription
|
||||
|
||||
buffer_diarization = ''
|
||||
if undiarized_text:
|
||||
buffer_diarization = self.sep.join(undiarized_text)
|
||||
|
||||
async with self.lock:
|
||||
self.state.end_attributed_speaker = state.end_attributed_speaker
|
||||
|
||||
buffer_translation_text = ''
|
||||
if state.buffer_translation:
|
||||
raw_buffer_translation = getattr(state.buffer_translation, 'text', state.buffer_translation)
|
||||
if raw_buffer_translation:
|
||||
buffer_translation_text = raw_buffer_translation.strip()
|
||||
|
||||
response_status = "active_transcription"
|
||||
if not state.tokens and not buffer_transcription and not buffer_diarization:
|
||||
response_status = "no_audio_detected"
|
||||
lines = []
|
||||
elif not lines:
|
||||
lines = [Line(
|
||||
speaker=1,
|
||||
start=state.end_buffer,
|
||||
end=state.end_buffer
|
||||
)]
|
||||
|
||||
response = FrontData(
|
||||
status=response_status,
|
||||
lines=lines,
|
||||
buffer_transcription=buffer_transcription.text.strip(),
|
||||
buffer_diarization=buffer_diarization,
|
||||
buffer_translation=buffer_translation_text,
|
||||
remaining_time_transcription=state.remaining_time_transcription,
|
||||
remaining_time_diarization=state.remaining_time_diarization if self.args.diarization else 0
|
||||
)
|
||||
|
||||
should_push = (response != self.last_response_content)
|
||||
if should_push and (lines or buffer_transcription or buffer_diarization or response_status == "no_audio_detected"):
|
||||
yield response
|
||||
self.last_response_content = response
|
||||
|
||||
if self.is_stopping and self._processing_tasks_done():
|
||||
logger.info("Results formatter: All upstream processors are done and in stopping state. Terminating.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception in results_formatter. Traceback: {traceback.format_exc()}")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def create_tasks(self):
|
||||
"""Create and start processing tasks."""
|
||||
self.all_tasks_for_cleanup = []
|
||||
processing_tasks_for_watchdog = []
|
||||
|
||||
# If using FFmpeg (non-PCM input), start it and spawn stdout reader
|
||||
if not self.is_pcm_input:
|
||||
success = await self.ffmpeg_manager.start()
|
||||
if not success:
|
||||
logger.error("Failed to start FFmpeg manager")
|
||||
async def error_generator():
|
||||
yield FrontData(
|
||||
status="error",
|
||||
error="FFmpeg failed to start. Please check that FFmpeg is installed."
|
||||
)
|
||||
return error_generator()
|
||||
self.ffmpeg_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
||||
self.all_tasks_for_cleanup.append(self.ffmpeg_reader_task)
|
||||
processing_tasks_for_watchdog.append(self.ffmpeg_reader_task)
|
||||
|
||||
if self.transcription:
|
||||
self.transcription_task = asyncio.create_task(self.transcription_processor())
|
||||
self.all_tasks_for_cleanup.append(self.transcription_task)
|
||||
processing_tasks_for_watchdog.append(self.transcription_task)
|
||||
|
||||
if self.diarization:
|
||||
self.diarization_task = asyncio.create_task(self.diarization_processor())
|
||||
self.all_tasks_for_cleanup.append(self.diarization_task)
|
||||
processing_tasks_for_watchdog.append(self.diarization_task)
|
||||
|
||||
if self.translation:
|
||||
self.translation_task = asyncio.create_task(self.translation_processor())
|
||||
self.all_tasks_for_cleanup.append(self.translation_task)
|
||||
processing_tasks_for_watchdog.append(self.translation_task)
|
||||
|
||||
# Monitor overall system health
|
||||
self.watchdog_task = asyncio.create_task(self.watchdog(processing_tasks_for_watchdog))
|
||||
self.all_tasks_for_cleanup.append(self.watchdog_task)
|
||||
|
||||
return self.results_formatter()
|
||||
|
||||
async def watchdog(self, tasks_to_monitor):
|
||||
"""Monitors the health of critical processing tasks."""
|
||||
tasks_remaining = [task for task in tasks_to_monitor if task]
|
||||
while True:
|
||||
try:
|
||||
if not tasks_remaining:
|
||||
logger.info("Watchdog task finishing: all monitored tasks completed.")
|
||||
return
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
for i, task in enumerate(list(tasks_remaining)):
|
||||
if task.done():
|
||||
exc = task.exception()
|
||||
task_name = task.get_name() if hasattr(task, 'get_name') else f"Monitored Task {i}"
|
||||
if exc:
|
||||
logger.error(f"{task_name} unexpectedly completed with exception: {exc}")
|
||||
else:
|
||||
logger.info(f"{task_name} completed normally.")
|
||||
tasks_remaining.remove(task)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Watchdog task cancelled.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in watchdog task: {e}", exc_info=True)
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources when processing is complete."""
|
||||
logger.info("Starting cleanup of AudioProcessor resources.")
|
||||
self.is_stopping = True
|
||||
for task in self.all_tasks_for_cleanup:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
created_tasks = [t for t in self.all_tasks_for_cleanup if t]
|
||||
if created_tasks:
|
||||
await asyncio.gather(*created_tasks, return_exceptions=True)
|
||||
logger.info("All processing tasks cancelled or finished.")
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
try:
|
||||
await self.ffmpeg_manager.stop()
|
||||
logger.info("FFmpeg manager stopped.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping FFmpeg manager: {e}")
|
||||
if self.diarization:
|
||||
self.diarization.close()
|
||||
logger.info("AudioProcessor cleanup complete.")
|
||||
|
||||
def _processing_tasks_done(self):
|
||||
"""Return True when all active processing tasks have completed."""
|
||||
tasks_to_check = [
|
||||
self.transcription_task,
|
||||
self.diarization_task,
|
||||
self.translation_task,
|
||||
self.ffmpeg_reader_task,
|
||||
]
|
||||
return all(task.done() for task in tasks_to_check if task)
|
||||
|
||||
|
||||
async def process_audio(self, message):
|
||||
"""Process incoming audio data."""
|
||||
|
||||
if not self.beg_loop:
|
||||
self.beg_loop = time()
|
||||
|
||||
if not message:
|
||||
logger.info("Empty audio message received, initiating stop sequence.")
|
||||
self.is_stopping = True
|
||||
|
||||
if self.transcription_queue:
|
||||
await self.transcription_queue.put(SENTINEL)
|
||||
|
||||
if not self.is_pcm_input and self.ffmpeg_manager:
|
||||
await self.ffmpeg_manager.stop()
|
||||
|
||||
return
|
||||
|
||||
if self.is_stopping:
|
||||
logger.warning("AudioProcessor is stopping. Ignoring incoming audio.")
|
||||
return
|
||||
|
||||
if self.is_pcm_input:
|
||||
self.pcm_buffer.extend(message)
|
||||
await self.handle_pcm_data()
|
||||
else:
|
||||
if not self.ffmpeg_manager:
|
||||
logger.error("FFmpeg manager not initialized for non-PCM input.")
|
||||
return
|
||||
success = await self.ffmpeg_manager.write_data(message)
|
||||
if not success:
|
||||
ffmpeg_state = await self.ffmpeg_manager.get_state()
|
||||
if ffmpeg_state == FFmpegState.FAILED:
|
||||
logger.error("FFmpeg is in FAILED state, cannot process audio")
|
||||
else:
|
||||
logger.warning("Failed to write audio data to FFmpeg")
|
||||
|
||||
async def handle_pcm_data(self):
|
||||
# Process when enough data
|
||||
if len(self.pcm_buffer) < self.bytes_per_sec:
|
||||
return
|
||||
|
||||
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
||||
logger.warning(
|
||||
f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
|
||||
f"Consider using a smaller model."
|
||||
)
|
||||
|
||||
chunk_size = min(len(self.pcm_buffer), self.max_bytes_per_sec)
|
||||
aligned_chunk_size = (chunk_size // self.bytes_per_sample) * self.bytes_per_sample
|
||||
|
||||
if aligned_chunk_size == 0:
|
||||
return
|
||||
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:aligned_chunk_size])
|
||||
self.pcm_buffer = self.pcm_buffer[aligned_chunk_size:]
|
||||
|
||||
num_samples = len(pcm_array)
|
||||
chunk_sample_start = self.total_pcm_samples
|
||||
chunk_sample_end = chunk_sample_start + num_samples
|
||||
|
||||
res = None
|
||||
if self.args.vac:
|
||||
res = self.vac(pcm_array)
|
||||
|
||||
if res is not None:
|
||||
silence_detected = res.get("end", 0) > res.get("start", 0)
|
||||
if silence_detected and not self.silence:
|
||||
pre_silence_chunk = self._slice_before_silence(
|
||||
pcm_array, chunk_sample_start, res.get("end")
|
||||
)
|
||||
if pre_silence_chunk is not None and pre_silence_chunk.size > 0:
|
||||
await self._enqueue_active_audio(pre_silence_chunk)
|
||||
await self._begin_silence()
|
||||
elif self.silence:
|
||||
await self._end_silence()
|
||||
|
||||
if not self.silence:
|
||||
await self._enqueue_active_audio(pcm_array)
|
||||
|
||||
self.total_pcm_samples = chunk_sample_end
|
||||
|
||||
if not self.args.transcription and not self.args.diarization:
|
||||
await asyncio.sleep(0.1)
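A driving sketch for the raw-PCM path, assuming a caller that already has 16 kHz s16le mono chunks; the keyword values and helper name below are illustrative, not part of this diff.
async def _run_pcm_stream(pcm_chunks):
    processor = AudioProcessor(pcm_input=True, transcription=True, diarization=False)
    results = await processor.create_tasks()

    async def _feed():
        for chunk in pcm_chunks:              # raw s16le bytes, 16 kHz mono
            await processor.process_audio(chunk)
        await processor.process_audio(b"")    # empty message starts the stop sequence

    feeder = asyncio.create_task(_feed())
    async for front_data in results:          # FrontData updates until processors finish
        print(front_data)
    await feeder
    await processor.cleanup()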
|
||||
41
whisperlivekit/backend_support.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def module_available(module_name):
|
||||
"""Return True if the given module can be imported."""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def mlx_backend_available(warn_on_missing = False):
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_arm = platform.machine() == "arm64"
|
||||
available = (
|
||||
is_macos
|
||||
and is_arm
|
||||
and module_available("mlx_whisper")
|
||||
)
|
||||
if not available and warn_on_missing and is_macos and is_arm:
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nMLX Whisper not found but you are on Apple Silicon. "
|
||||
"Consider installing mlx-whisper for better performance: "
|
||||
"`pip install mlx-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
|
||||
|
||||
|
||||
def faster_backend_available(warn_on_missing = False):
|
||||
available = module_available("faster_whisper")
|
||||
if not available and warn_on_missing and platform.system() != "Darwin":
|
||||
logger.warning(
|
||||
"=" * 50
|
||||
+ "\nFaster-Whisper not found. Consider installing faster-whisper "
|
||||
"for better performance: `pip install faster-whisper`\n"
|
||||
+ "=" * 50
|
||||
)
|
||||
return available
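A small selection sketch built on these helpers; the function name and fallback string are assumptions for illustration, the actual backend resolution in this PR happens in core.py.
def _pick_backend():
    if mlx_backend_available(warn_on_missing=True):
        return "mlx-whisper"
    if faster_backend_available(warn_on_missing=True):
        return "faster-whisper"
    return "whisper"   # plain PyTorch whisper as the assumed fallback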
|
||||
127
whisperlivekit/basic_server.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from whisperlivekit import TranscriptionEngine, AudioProcessor, get_inline_ui_html, parse_args
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
args = parse_args()
|
||||
transcription_engine = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global transcription_engine
|
||||
transcription_engine = TranscriptionEngine(
|
||||
**vars(args),
|
||||
)
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_inline_ui_html())
|
||||
|
||||
|
||||
async def handle_websocket_results(websocket, results_generator):
|
||||
"""Consumes results from the audio processor and sends them via WebSocket."""
|
||||
try:
|
||||
async for response in results_generator:
|
||||
await websocket.send_json(response.to_dict())
|
||||
# when the results_generator finishes it means all audio has been processed
|
||||
logger.info("Results generator finished. Sending 'ready_to_stop' to client.")
|
||||
await websocket.send_json({"type": "ready_to_stop"})
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected while handling results (client likely closed connection).")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in WebSocket results handler: {e}")
|
||||
|
||||
|
||||
@app.websocket("/asr")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global transcription_engine
|
||||
audio_processor = AudioProcessor(
|
||||
transcription_engine=transcription_engine,
|
||||
)
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket connection opened.")
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "config", "useAudioWorklet": bool(args.pcm_input)})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send config to client: {e}")
|
||||
|
||||
results_generator = await audio_processor.create_tasks()
|
||||
websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_bytes()
|
||||
await audio_processor.process_audio(message)
|
||||
except KeyError as e:
|
||||
if 'bytes' in str(e):
|
||||
logger.warning(f"Client has closed the connection.")
|
||||
else:
|
||||
logger.error(f"Unexpected KeyError in websocket_endpoint: {e}", exc_info=True)
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected by client during message receiving loop.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in websocket_endpoint main loop: {e}", exc_info=True)
|
||||
finally:
|
||||
logger.info("Cleaning up WebSocket endpoint...")
|
||||
if not websocket_task.done():
|
||||
websocket_task.cancel()
|
||||
try:
|
||||
await websocket_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("WebSocket results handler task was cancelled.")
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception while awaiting websocket_task completion: {e}")
|
||||
|
||||
await audio_processor.cleanup()
|
||||
logger.info("WebSocket endpoint cleaned up successfully.")
|
||||
|
||||
def main():
|
||||
"""Entry point for the CLI command."""
|
||||
import uvicorn
|
||||
|
||||
uvicorn_kwargs = {
|
||||
"app": "whisperlivekit.basic_server:app",
|
||||
"host":args.host,
|
||||
"port":args.port,
|
||||
"reload": False,
|
||||
"log_level": "info",
|
||||
"lifespan": "on",
|
||||
}
|
||||
|
||||
ssl_kwargs = {}
|
||||
if args.ssl_certfile or args.ssl_keyfile:
|
||||
if not (args.ssl_certfile and args.ssl_keyfile):
|
||||
raise ValueError("Both --ssl-certfile and --ssl-keyfile must be specified together.")
|
||||
ssl_kwargs = {
|
||||
"ssl_certfile": args.ssl_certfile,
|
||||
"ssl_keyfile": args.ssl_keyfile
|
||||
}
|
||||
|
||||
if ssl_kwargs:
|
||||
uvicorn_kwargs = {**uvicorn_kwargs, **ssl_kwargs}
|
||||
if args.forwarded_allow_ips:
|
||||
uvicorn_kwargs = { **uvicorn_kwargs, "forwarded_allow_ips" : args.forwarded_allow_ips }
|
||||
|
||||
uvicorn.run(**uvicorn_kwargs)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
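A hypothetical client sketch (the `websockets` package and the file-chunking rhythm below are assumptions, not part of this PR) that streams an encoded audio file to /asr and prints each JSON update until the server reports ready_to_stop.
import asyncio, json, websockets

async def stream_file(path, url="ws://localhost:8000/asr"):
    async with websockets.connect(url) as ws:
        print(json.loads(await ws.recv()))            # initial {"type": "config", ...} message

        async def reader():
            async for msg in ws:
                data = json.loads(msg)
                print(data)
                if data.get("type") == "ready_to_stop":
                    break

        reader_task = asyncio.create_task(reader())
        with open(path, "rb") as f:
            while chunk := f.read(4096):
                await ws.send(chunk)
                await asyncio.sleep(0.05)             # crude pacing for a pre-recorded file
        await ws.send(b"")                            # empty frame signals end of stream
        await reader_task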
|
||||
195
whisperlivekit/core.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from whisperlivekit.local_agreement.whisper_online import backend_factory
|
||||
from whisperlivekit.simul_whisper import SimulStreamingASR
|
||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||
from argparse import Namespace
|
||||
import sys
|
||||
import logging
|
||||
|
||||
def update_with_kwargs(_dict, kwargs):
|
||||
_dict.update({
|
||||
k: v for k, v in kwargs.items() if k in _dict
|
||||
})
|
||||
return _dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TranscriptionEngine:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if TranscriptionEngine._initialized:
|
||||
return
|
||||
|
||||
global_params = {
|
||||
"host": "localhost",
|
||||
"port": 8000,
|
||||
"diarization": False,
|
||||
"punctuation_split": False,
|
||||
"target_language": "",
|
||||
"vac": True,
|
||||
"vac_onnx": False,
|
||||
"vac_chunk_size": 0.04,
|
||||
"log_level": "DEBUG",
|
||||
"ssl_certfile": None,
|
||||
"ssl_keyfile": None,
|
||||
"forwarded_allow_ips": None,
|
||||
"transcription": True,
|
||||
"vad": True,
|
||||
"pcm_input": False,
|
||||
"disable_punctuation_split" : False,
|
||||
"diarization_backend": "sortformer",
|
||||
"backend_policy": "simulstreaming",
|
||||
"backend": "auto",
|
||||
}
|
||||
global_params = update_with_kwargs(global_params, kwargs)
|
||||
|
||||
transcription_common_params = {
|
||||
"warmup_file": None,
|
||||
"min_chunk_size": 0.1,
|
||||
"model_size": "base",
|
||||
"model_cache_dir": None,
|
||||
"model_dir": None,
|
||||
"model_path": None,
|
||||
"lan": "auto",
|
||||
"direct_english_translation": False,
|
||||
}
|
||||
transcription_common_params = update_with_kwargs(transcription_common_params, kwargs)
|
||||
|
||||
if transcription_common_params['model_size'].endswith(".en"):
|
||||
transcription_common_params["lan"] = "en"
|
||||
if 'no_transcription' in kwargs:
|
||||
global_params['transcription'] = not kwargs['no_transcription']
|
||||
if 'no_vad' in kwargs:
|
||||
global_params['vad'] = not kwargs['no_vad']
|
||||
if 'no_vac' in kwargs:
|
||||
global_params['vac'] = not kwargs['no_vac']
|
||||
|
||||
self.args = Namespace(**{**global_params, **transcription_common_params})
|
||||
|
||||
self.asr = None
|
||||
self.tokenizer = None
|
||||
self.diarization = None
|
||||
self.vac_model = None
|
||||
|
||||
if self.args.vac:
|
||||
from whisperlivekit.silero_vad_iterator import load_silero_vad
|
||||
# Use ONNX if specified, otherwise use JIT (default)
|
||||
use_onnx = kwargs.get('vac_onnx', False)
|
||||
self.vac_model = load_silero_vad(onnx=use_onnx)
|
||||
|
||||
backend_policy = self.args.backend_policy
|
||||
if self.args.transcription:
|
||||
if backend_policy == "simulstreaming":
|
||||
simulstreaming_params = {
|
||||
"disable_fast_encoder": False,
|
||||
"custom_alignment_heads": None,
|
||||
"frame_threshold": 25,
|
||||
"beams": 1,
|
||||
"decoder_type": None,
|
||||
"audio_max_len": 20.0,
|
||||
"audio_min_len": 0.0,
|
||||
"cif_ckpt_path": None,
|
||||
"never_fire": False,
|
||||
"init_prompt": None,
|
||||
"static_init_prompt": None,
|
||||
"max_context_tokens": None,
|
||||
"preload_model_count": 1,
|
||||
}
|
||||
simulstreaming_params = update_with_kwargs(simulstreaming_params, kwargs)
|
||||
|
||||
self.tokenizer = None
|
||||
self.asr = SimulStreamingASR(
|
||||
**transcription_common_params,
|
||||
**simulstreaming_params,
|
||||
backend=self.args.backend,
|
||||
)
|
||||
logger.info(
|
||||
"Using SimulStreaming policy with %s backend",
|
||||
getattr(self.asr, "encoder_backend", "whisper"),
|
||||
)
|
||||
else:
|
||||
|
||||
whisperstreaming_params = {
|
||||
"buffer_trimming": "segment",
|
||||
"confidence_validation": False,
|
||||
"buffer_trimming_sec": 15,
|
||||
}
|
||||
whisperstreaming_params = update_with_kwargs(whisperstreaming_params, kwargs)
|
||||
|
||||
self.asr = backend_factory(
|
||||
backend=self.args.backend,
|
||||
**transcription_common_params,
|
||||
**whisperstreaming_params,
|
||||
)
|
||||
logger.info(
|
||||
"Using LocalAgreement policy with %s backend",
|
||||
getattr(self.asr, "backend_choice", self.asr.__class__.__name__),
|
||||
)
|
||||
|
||||
if self.args.diarization:
|
||||
if self.args.diarization_backend == "diart":
|
||||
from whisperlivekit.diarization.diart_backend import DiartDiarization
|
||||
diart_params = {
|
||||
"segmentation_model": "pyannote/segmentation-3.0",
|
||||
"embedding_model": "pyannote/embedding",
|
||||
}
|
||||
diart_params = update_with_kwargs(diart_params, kwargs)
|
||||
self.diarization_model = DiartDiarization(
|
||||
block_duration=self.args.min_chunk_size,
|
||||
**diart_params
|
||||
)
|
||||
elif self.args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarization
|
||||
self.diarization_model = SortformerDiarization()
|
||||
|
||||
self.translation_model = None
|
||||
if self.args.target_language:
|
||||
if self.args.lan == 'auto' and backend_policy != "simulstreaming":
|
||||
raise Exception("Translation cannot be used with language 'auto' unless the transcription backend is simulstreaming")
|
||||
else:
|
||||
try:
|
||||
from nllw import load_model
|
||||
except ImportError:
|
||||
raise Exception('To use translation, you must install nllw: `pip install nllw`')
|
||||
translation_params = {
|
||||
"nllb_backend": "transformers",
|
||||
"nllb_size": "600M"
|
||||
}
|
||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||
TranscriptionEngine._initialized = True
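A short sharing sketch: because TranscriptionEngine is a singleton, the heavy models load once and each connection gets its own lightweight AudioProcessor (the keyword values below are illustrative).
from whisperlivekit import TranscriptionEngine, AudioProcessor

engine = TranscriptionEngine(model_size="base", diarization=False)   # loads models once
processor_a = AudioProcessor(transcription_engine=engine)
processor_b = AudioProcessor(transcription_engine=engine)            # reuses the same loaded models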
|
||||
|
||||
|
||||
def online_factory(args, asr):
|
||||
if args.backend_policy == "simulstreaming":
|
||||
from whisperlivekit.simul_whisper import SimulStreamingOnlineProcessor
|
||||
online = SimulStreamingOnlineProcessor(asr)
|
||||
else:
|
||||
online = OnlineASRProcessor(asr)
|
||||
return online
|
||||
|
||||
|
||||
def online_diarization_factory(args, diarization_backend):
|
||||
if args.diarization_backend == "diart":
|
||||
online = diarization_backend
|
||||
# Not ideal: several users/instances will share the same diart backend, but diart is no longer SOTA and sortformer is the recommended default
|
||||
|
||||
if args.diarization_backend == "sortformer":
|
||||
from whisperlivekit.diarization.sortformer_backend import SortformerDiarizationOnline
|
||||
online = SortformerDiarizationOnline(shared_model=diarization_backend)
|
||||
return online
|
||||
|
||||
|
||||
def online_translation_factory(args, translation_model):
|
||||
#should be at speaker level in the future:
|
||||
#one shared nllb model for all speaker
|
||||
#one tokenizer per speaker/language
|
||||
from nllw import OnlineTranslation
|
||||
return OnlineTranslation(translation_model, [args.lan], [args.target_language])
|
||||
0
whisperlivekit/diarization/__init__.py
Normal file
288
whisperlivekit/diarization/diart_backend.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from diart import SpeakerDiarization, SpeakerDiarizationConfig
|
||||
from diart.inference import StreamingInference
|
||||
from diart.sources import AudioSource
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
from diart.sources import MicrophoneAudioSource
|
||||
from rx.core import Observer
|
||||
from typing import Tuple, Any, List
|
||||
from pyannote.core import Annotation
|
||||
import diart.models as m
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else None
|
||||
|
||||
class DiarizationObserver(Observer):
|
||||
"""Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
|
||||
|
||||
def __init__(self):
|
||||
self.diarization_segments = []
|
||||
self.processed_time = 0
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
|
||||
def on_next(self, value: Tuple[Annotation, Any]):
|
||||
annotation, audio = value
|
||||
|
||||
logger.debug("\n--- New Diarization Result ---")
|
||||
|
||||
duration = audio.extent.end - audio.extent.start
|
||||
logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
|
||||
logger.debug(f"Audio shape: {audio.data.shape}")
|
||||
|
||||
with self.segment_lock:
|
||||
if audio.extent.end > self.processed_time:
|
||||
self.processed_time = audio.extent.end
|
||||
if annotation and len(annotation._labels) > 0:
|
||||
logger.debug("\nSpeaker segments:")
|
||||
for speaker, label in annotation._labels.items():
|
||||
for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
|
||||
print(f" {speaker}: {start:.2f}s-{end:.2f}s")
|
||||
self.diarization_segments.append(SpeakerSegment(
|
||||
speaker=speaker,
|
||||
start=start + self.global_time_offset,
|
||||
end=end + self.global_time_offset
|
||||
))
|
||||
else:
|
||||
logger.debug("\nNo speakers detected in this segment")
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def clear_old_segments(self, older_than: float = 30.0):
|
||||
"""Clear segments older than the specified time."""
|
||||
with self.segment_lock:
|
||||
current_time = self.processed_time
|
||||
self.diarization_segments = [
|
||||
segment for segment in self.diarization_segments
|
||||
if current_time - segment.end < older_than
|
||||
]
|
||||
|
||||
def on_error(self, error):
|
||||
"""Handle an error in the stream."""
|
||||
logger.debug(f"Error in diarization stream: {error}")
|
||||
|
||||
def on_completed(self):
|
||||
"""Handle the completion of the stream."""
|
||||
logger.debug("Diarization stream completed")
|
||||
|
||||
|
||||
class WebSocketAudioSource(AudioSource):
|
||||
"""
|
||||
Buffers incoming audio and releases it in fixed-size chunks at regular intervals.
|
||||
"""
|
||||
def __init__(self, uri: str = "websocket", sample_rate: int = 16000, block_duration: float = 0.5):
|
||||
super().__init__(uri, sample_rate)
|
||||
self.block_duration = block_duration
|
||||
self.block_size = int(np.rint(block_duration * sample_rate))
|
||||
self._queue = SimpleQueue()
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
self._buffer_lock = threading.Lock()
|
||||
self._closed = False
|
||||
self._close_event = threading.Event()
|
||||
self._processing_thread = None
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
def read(self):
|
||||
"""Start processing buffered audio and emit fixed-size chunks."""
|
||||
self._processing_thread = threading.Thread(target=self._process_chunks)
|
||||
self._processing_thread.daemon = True
|
||||
self._processing_thread.start()
|
||||
|
||||
self._close_event.wait()
|
||||
if self._processing_thread:
|
||||
self._processing_thread.join(timeout=2.0)
|
||||
|
||||
def _process_chunks(self):
|
||||
"""Process audio from queue and emit fixed-size chunks at regular intervals."""
|
||||
while not self._closed:
|
||||
try:
|
||||
audio_chunk = self._queue.get(timeout=0.1)
|
||||
|
||||
with self._buffer_lock:
|
||||
self._buffer = np.concatenate([self._buffer, audio_chunk])
|
||||
|
||||
while len(self._buffer) >= self.block_size:
|
||||
chunk = self._buffer[:self.block_size]
|
||||
self._buffer = self._buffer[self.block_size:]
|
||||
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - self._last_chunk_time
|
||||
if time_since_last < self.block_duration:
|
||||
time.sleep(self.block_duration - time_since_last)
|
||||
|
||||
chunk_reshaped = chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
|
||||
except Empty:
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0 and time.time() - self._last_chunk_time > self.block_duration:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
self._buffer = np.array([], dtype=np.float32)
|
||||
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
self._last_chunk_time = time.time()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio processing thread: {e}")
|
||||
self.stream.on_error(e)
|
||||
break
|
||||
|
||||
with self._buffer_lock:
|
||||
if len(self._buffer) > 0:
|
||||
padded_chunk = np.zeros(self.block_size, dtype=np.float32)
|
||||
padded_chunk[:len(self._buffer)] = self._buffer
|
||||
chunk_reshaped = padded_chunk.reshape(1, -1)
|
||||
self.stream.on_next(chunk_reshaped)
|
||||
|
||||
self.stream.on_completed()
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
self._closed = True
|
||||
self._close_event.set()
|
||||
|
||||
def push_audio(self, chunk: np.ndarray):
|
||||
"""Add audio chunk to the processing queue."""
|
||||
if not self._closed:
|
||||
if chunk.ndim > 1:
|
||||
chunk = chunk.flatten()
|
||||
self._queue.put(chunk)
|
||||
logger.debug(f'Added chunk to queue with {len(chunk)} samples')
|
||||
|
||||
|
||||
class DiartDiarization:
|
||||
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False, block_duration: float = 1.5, segmentation_model_name: str = "pyannote/segmentation-3.0", embedding_model_name: str = "pyannote/embedding"):
|
||||
segmentation_model = m.SegmentationModel.from_pretrained(segmentation_model_name)
|
||||
embedding_model = m.EmbeddingModel.from_pretrained(embedding_model_name)
|
||||
|
||||
if config is None:
|
||||
config = SpeakerDiarizationConfig(
|
||||
segmentation=segmentation_model,
|
||||
embedding=embedding_model,
|
||||
)
|
||||
|
||||
self.pipeline = SpeakerDiarization(config=config)
|
||||
self.observer = DiarizationObserver()
|
||||
|
||||
if use_microphone:
|
||||
self.source = MicrophoneAudioSource(block_duration=block_duration)
|
||||
self.custom_source = None
|
||||
else:
|
||||
self.custom_source = WebSocketAudioSource(
|
||||
uri="websocket_source",
|
||||
sample_rate=sample_rate,
|
||||
block_duration=block_duration
|
||||
)
|
||||
self.source = self.custom_source
|
||||
|
||||
self.inference = StreamingInference(
|
||||
pipeline=self.pipeline,
|
||||
source=self.source,
|
||||
do_plot=False,
|
||||
show_progress=False,
|
||||
)
|
||||
self.inference.attach_observers(self.observer)
|
||||
asyncio.get_event_loop().run_in_executor(None, self.inference)
|
||||
|
||||
def insert_silence(self, silence_duration):
|
||||
self.observer.global_time_offset += silence_duration
|
||||
|
||||
async def diarize(self, pcm_array: np.ndarray):
|
||||
"""
|
||||
Process audio data for diarization.
|
||||
Only used when working with WebSocketAudioSource.
|
||||
"""
|
||||
if self.custom_source:
|
||||
self.custom_source.push_audio(pcm_array)
|
||||
# self.observer.clear_old_segments()
|
||||
|
||||
def close(self):
|
||||
"""Close the audio source."""
|
||||
if self.custom_source:
|
||||
self.custom_source.close()
|
||||
|
||||
|
||||
def concatenate_speakers(segments):
|
||||
segments_concatenated = [{"speaker": 1, "begin": 0.0, "end": 0.0}]
|
||||
for segment in segments:
|
||||
speaker = extract_number(segment.speaker) + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
# print("Segments concatenated:")
|
||||
# for entry in segments_concatenated:
|
||||
# print(f"Speaker {entry['speaker']}: {entry['begin']:.2f}s - {entry['end']:.2f}s")
|
||||
return segments_concatenated
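A worked sketch (speaker labels and times are made up) of what concatenate_speakers produces: adjacent segments from the same diart speaker collapse into one entry, and labels are renumbered starting at 1.
_segs = [
    SpeakerSegment(speaker="speaker_0", start=0.0, end=2.0),
    SpeakerSegment(speaker="speaker_1", start=2.0, end=3.0),
    SpeakerSegment(speaker="speaker_1", start=3.0, end=5.0),
]
# concatenate_speakers(_segs) ->
# [{'speaker': 1, 'begin': 0.0, 'end': 2.0}, {'speaker': 2, 'begin': 2.0, 'end': 5.0}]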
|
||||
|
||||
|
||||
def add_speaker_to_tokens(segments, tokens):
|
||||
"""
|
||||
Assign speakers to tokens based on diarization segments, with punctuation-aware boundary adjustment.
|
||||
"""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
# print(
|
||||
# f"Token '{token.text}' ('begin': {token.start:.2f}, 'end': {token.end:.2f}) "
|
||||
# f"assigned to Speaker {segment['speaker']} ('segment': {segment['begin']:.2f}-{segment['end']:.2f})"
|
||||
# )
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
return tokens
|
||||
|
||||
|
||||
def visualize_tokens(tokens):
|
||||
conversation = [{"speaker": -1, "text": ""}]
|
||||
for token in tokens:
|
||||
speaker = conversation[-1]['speaker']
|
||||
if token.speaker != speaker:
|
||||
conversation.append({"speaker": token.speaker, "text": token.text})
|
||||
else:
|
||||
conversation[-1]['text'] += token.text
|
||||
print("Conversation:")
|
||||
for entry in conversation:
|
||||
print(f"Speaker {entry['speaker']}: {entry['text']}")
|
||||
329
whisperlivekit/diarization/sortformer_backend.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from typing import List, Optional
|
||||
from queue import SimpleQueue, Empty
|
||||
|
||||
from whisperlivekit.timed_objects import SpeakerSegment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from nemo.collections.asr.models import SortformerEncLabelModel
|
||||
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
|
||||
except ImportError:
|
||||
raise SystemExit("""Please use `pip install "git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]"` to use the Sortformer diarization""")
|
||||
|
||||
|
||||
class StreamingSortformerState:
|
||||
"""
|
||||
This class creates a class instance that will be used to store the state of the
|
||||
streaming Sortformer model.
|
||||
|
||||
Attributes:
|
||||
spkcache (torch.Tensor): Speaker cache to store embeddings from start
|
||||
spkcache_lengths (torch.Tensor): Lengths of the speaker cache
|
||||
spkcache_preds (torch.Tensor): The speaker predictions for the speaker cache parts
|
||||
fifo (torch.Tensor): FIFO queue to save the embedding from the latest chunks
|
||||
fifo_lengths (torch.Tensor): Lengths of the FIFO queue
|
||||
fifo_preds (torch.Tensor): The speaker predictions for the FIFO queue parts
|
||||
spk_perm (torch.Tensor): Speaker permutation information for the speaker cache
|
||||
mean_sil_emb (torch.Tensor): Mean silence embedding
|
||||
n_sil_frames (torch.Tensor): Number of silence frames
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.spkcache = None # Speaker cache to store embeddings from start
|
||||
self.spkcache_lengths = None
|
||||
self.spkcache_preds = None # speaker cache predictions
|
||||
self.fifo = None # to save the embedding from the latest chunks
|
||||
self.fifo_lengths = None
|
||||
self.fifo_preds = None
|
||||
self.spk_perm = None
|
||||
self.mean_sil_emb = None
|
||||
self.n_sil_frames = None
|
||||
|
||||
|
||||
class SortformerDiarization:
|
||||
def __init__(self, model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"):
|
||||
"""
|
||||
Stores the shared streaming Sortformer diarization model. Used when a new online_diarization is initialized.
|
||||
"""
|
||||
self._load_model(model_name)
|
||||
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and configure the Sortformer model for streaming."""
|
||||
try:
|
||||
self.diar_model = SortformerEncLabelModel.from_pretrained(model_name)
|
||||
self.diar_model.eval()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.diar_model.to(device)
|
||||
|
||||
## to test
|
||||
# for name, param in self.diar_model.named_parameters():
|
||||
# if param.device != device:
|
||||
# raise RuntimeError(f"Parameter {name} is on {param.device} but should be on {device}")
|
||||
|
||||
logger.info(f"Using {device.type.upper()} for Sortformer model")
|
||||
|
||||
self.diar_model.sortformer_modules.chunk_len = 10
|
||||
self.diar_model.sortformer_modules.subsampling_factor = 10
|
||||
self.diar_model.sortformer_modules.chunk_right_context = 0
|
||||
self.diar_model.sortformer_modules.chunk_left_context = 10
|
||||
self.diar_model.sortformer_modules.spkcache_len = 188
|
||||
self.diar_model.sortformer_modules.fifo_len = 188
|
||||
self.diar_model.sortformer_modules.spkcache_update_period = 144
|
||||
self.diar_model.sortformer_modules.log = False
|
||||
self.diar_model.sortformer_modules._check_streaming_parameters()
|
||||
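            # Streaming geometry (reading off the parameters above, as an assumption about the
            # NeMo internals): each step consumes `chunk_len` frames of the subsampled encoder
            # output with `chunk_left_context` frames of left context, while the fixed-size
            # speaker cache and FIFO buffers (spkcache_len / fifo_len) bound memory use.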
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Sortformer model: {e}")
|
||||
raise
|
||||
|
||||
class SortformerDiarizationOnline:
|
||||
def __init__(self, shared_model, sample_rate: int = 16000):
|
||||
"""
|
||||
Initialize the streaming Sortformer diarization system.
|
||||
|
||||
Args:
|
||||
sample_rate: Audio sample rate (default: 16000)
|
||||
model_name: Pre-trained model name (default: "nvidia/diar_streaming_sortformer_4spk-v2")
|
||||
"""
|
||||
self.sample_rate = sample_rate
|
||||
self.diarization_segments = []
|
||||
self.diar_segments = []
|
||||
self.buffer_audio = np.array([], dtype=np.float32)
|
||||
self.segment_lock = threading.Lock()
|
||||
self.global_time_offset = 0.0
|
||||
self.debug = False
|
||||
|
||||
self.diar_model = shared_model.diar_model
|
||||
|
||||
self.audio2mel = AudioToMelSpectrogramPreprocessor(
|
||||
window_size=0.025,
|
||||
normalize="NA",
|
||||
n_fft=512,
|
||||
features=128,
|
||||
pad_to=0
|
||||
)
|
||||
self.audio2mel.to(self.diar_model.device)
|
||||
|
||||
self.chunk_duration_seconds = (
|
||||
self.diar_model.sortformer_modules.chunk_len *
|
||||
self.diar_model.sortformer_modules.subsampling_factor *
|
||||
self.diar_model.preprocessor._cfg.window_stride
|
||||
)
|
||||
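        # Rough arithmetic (assuming the preprocessor's default 10 ms window stride):
        # chunk_len (10) * subsampling_factor (10) * 0.01 s = 1.0 s, i.e. about one second
        # of audio is buffered before each streaming inference step.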
|
||||
self._init_streaming_state()
|
||||
|
||||
self._previous_chunk_features = None
|
||||
self._chunk_index = 0
|
||||
self._len_prediction = None
|
||||
|
||||
# Audio buffer to store PCM chunks for debugging
|
||||
self.audio_buffer = []
|
||||
|
||||
# Buffer for accumulating audio chunks until reaching chunk_duration_seconds
|
||||
self.audio_chunk_buffer = []
|
||||
self.accumulated_duration = 0.0
|
||||
|
||||
logger.info("SortformerDiarization initialized successfully")
|
||||
|
||||
|
||||
def _init_streaming_state(self):
|
||||
"""Initialize the streaming state for the model."""
|
||||
batch_size = 1
|
||||
device = self.diar_model.device
|
||||
|
||||
self.streaming_state = StreamingSortformerState()
|
||||
self.streaming_state.spkcache = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_preds = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.spkcache_len, self.diar_model.sortformer_modules.n_spk),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.spkcache_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.fifo = torch.zeros(
|
||||
(batch_size, self.diar_model.sortformer_modules.fifo_len, self.diar_model.sortformer_modules.fc_d_model),
|
||||
device=device
|
||||
)
|
||||
self.streaming_state.fifo_lengths = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.streaming_state.mean_sil_emb = torch.zeros((batch_size, self.diar_model.sortformer_modules.fc_d_model), device=device)
|
||||
self.streaming_state.n_sil_frames = torch.zeros((batch_size,), dtype=torch.long, device=device)
|
||||
self.total_preds = torch.zeros((batch_size, 0, self.diar_model.sortformer_modules.n_spk), device=device)
|
||||
|
||||
def insert_silence(self, silence_duration: Optional[float]):
|
||||
"""
|
||||
Insert silence period by adjusting the global time offset.
|
||||
|
||||
Args:
|
||||
silence_duration: Duration of silence in seconds
|
||||
"""
|
||||
with self.segment_lock:
|
||||
self.global_time_offset += silence_duration
|
||||
logger.debug(f"Inserted silence of {silence_duration:.2f}s, new offset: {self.global_time_offset:.2f}s")
|
||||
|
||||
def insert_audio_chunk(self, pcm_array: np.ndarray):
|
||||
if self.debug:
|
||||
self.audio_buffer.append(pcm_array.copy())
|
||||
self.buffer_audio = np.concatenate([self.buffer_audio, pcm_array.copy()])
|
||||
|
||||
|
||||
async def diarize(self):
|
||||
"""
|
||||
Process audio data for diarization in streaming fashion.
|
||||
|
||||
Args:
|
||||
pcm_array: Audio data as numpy array
|
||||
"""
|
||||
|
||||
threshold = int(self.chunk_duration_seconds * self.sample_rate)
|
||||
|
||||
if not len(self.buffer_audio) >= threshold:
|
||||
return []
|
||||
|
||||
audio = self.buffer_audio[:threshold]
|
||||
self.buffer_audio = self.buffer_audio[threshold:]
|
||||
|
||||
device = self.diar_model.device
|
||||
audio_signal_chunk = torch.tensor(audio, device=device).unsqueeze(0)
|
||||
audio_signal_length_chunk = torch.tensor([audio_signal_chunk.shape[1]], device=device)
|
||||
|
||||
processed_signal_chunk, processed_signal_length_chunk = self.audio2mel.get_features(
|
||||
audio_signal_chunk, audio_signal_length_chunk
|
||||
)
|
||||
processed_signal_chunk = processed_signal_chunk.to(device)
|
||||
processed_signal_length_chunk = processed_signal_length_chunk.to(device)
|
||||
|
||||
if self._previous_chunk_features is not None:
|
||||
to_add = self._previous_chunk_features[:, :, -99:].to(device)
|
||||
total_features = torch.concat([to_add, processed_signal_chunk], dim=2).to(device)
|
||||
else:
|
||||
total_features = processed_signal_chunk.to(device)
|
||||
|
||||
self._previous_chunk_features = processed_signal_chunk.to(device)
|
||||
|
||||
chunk_feat_seq_t = torch.transpose(total_features, 1, 2).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
left_offset = 8 if self._chunk_index > 0 else 0
|
||||
right_offset = 8
|
||||
|
||||
self.streaming_state, self.total_preds = self.diar_model.forward_streaming_step(
|
||||
processed_signal=chunk_feat_seq_t,
|
||||
processed_signal_length=torch.tensor([chunk_feat_seq_t.shape[1]]).to(device),
|
||||
streaming_state=self.streaming_state,
|
||||
total_preds=self.total_preds,
|
||||
left_offset=left_offset,
|
||||
right_offset=right_offset,
|
||||
)
|
||||
new_segments = self._process_predictions()
|
||||
|
||||
self._chunk_index += 1
|
||||
return new_segments
|
||||
|
||||
def _process_predictions(self):
|
||||
"""Process model predictions and convert to speaker segments."""
|
||||
preds_np = self.total_preds[0].cpu().numpy()
|
||||
active_speakers = np.argmax(preds_np, axis=1)
|
||||
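        # total_preds accumulates per-frame speaker activities over all chunks processed so
        # far; argmax keeps only the single most active speaker per frame, so overlapping
        # speech collapses to one label in the emitted segments.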
|
||||
if self._len_prediction is None:
|
||||
            self._len_prediction = len(active_speakers)  # frames produced for the first chunk; assumed constant afterwards
|
||||
|
||||
frame_duration = self.chunk_duration_seconds / self._len_prediction
|
||||
current_chunk_preds = active_speakers[-self._len_prediction:]
|
||||
|
||||
new_segments = []
|
||||
|
||||
with self.segment_lock:
|
||||
base_time = self._chunk_index * self.chunk_duration_seconds + self.global_time_offset
|
||||
current_spk = current_chunk_preds[0]
|
||||
start_time = round(base_time, 2)
|
||||
for idx, spk in enumerate(current_chunk_preds):
|
||||
current_time = round(base_time + idx * frame_duration, 2)
|
||||
if spk != current_spk:
|
||||
new_segments.append(SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
))
|
||||
start_time = current_time
|
||||
current_spk = spk
|
||||
new_segments.append(
|
||||
SpeakerSegment(
|
||||
speaker=current_spk,
|
||||
start=start_time,
|
||||
end=current_time
|
||||
)
|
||||
)
|
||||
return new_segments
|
||||
|
||||
def get_segments(self) -> List[SpeakerSegment]:
|
||||
"""Get a copy of the current speaker segments."""
|
||||
with self.segment_lock:
|
||||
return self.diarization_segments.copy()
|
||||
|
||||
def close(self):
|
||||
"""Close the diarization system and clean up resources."""
|
||||
logger.info("Closing SortformerDiarization")
|
||||
with self.segment_lock:
|
||||
self.diarization_segments.clear()
|
||||
|
||||
if self.debug:
|
||||
concatenated_audio = np.concatenate(self.audio_buffer)
|
||||
audio_data_int16 = (concatenated_audio * 32767).astype(np.int16)
|
||||
with wave.open("diarization_audio.wav", "wb") as wav_file:
|
||||
wav_file.setnchannels(1) # mono audio
|
||||
wav_file.setsampwidth(2) # 2 bytes per sample (int16)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio_data_int16.tobytes())
|
||||
logger.info(f"Saved {len(concatenated_audio)} samples to diarization_audio.wav")
|
||||
|
||||
|
||||
def extract_number(s: str) -> int:
|
||||
"""Extract number from speaker string (compatibility function)."""
|
||||
import re
|
||||
m = re.search(r'\d+', s)
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import asyncio
|
||||
import librosa
|
||||
|
||||
async def main():
|
||||
"""TEST ONLY."""
|
||||
an4_audio = 'diarization_audio.wav'
|
||||
signal, sr = librosa.load(an4_audio, sr=16000)
|
||||
signal = signal[:16000*30]
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("ground truth:")
|
||||
print("Speaker 0: 0:00 - 0:09")
|
||||
print("Speaker 1: 0:09 - 0:19")
|
||||
print("Speaker 2: 0:19 - 0:25")
|
||||
print("Speaker 0: 0:25 - 0:30")
|
||||
print("=" * 50)
|
||||
|
||||
diarization_backend = SortformerDiarization()
|
||||
diarization = SortformerDiarizationOnline(shared_model = diarization_backend)
|
||||
chunk_size = 1600
|
||||
|
||||
        for i in range(0, len(signal), chunk_size):
            chunk = signal[i:i+chunk_size]
            diarization.insert_audio_chunk(chunk)
            new_segments = await diarization.diarize()
            print(f"Processed chunk {i // chunk_size + 1}")
            print(new_segments)
|
||||
|
||||
segments = diarization.get_segments()
|
||||
print("\nDiarization results:")
|
||||
for segment in segments:
|
||||
print(f"Speaker {segment.speaker}: {segment.start:.2f}s - {segment.end:.2f}s")
|
||||
|
||||
asyncio.run(main())
|
||||
whisperlivekit/ffmpeg_manager.py  (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
import contextlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
ERROR_INSTALL_INSTRUCTIONS = f"""
|
||||
{'='*50}
|
||||
FFmpeg is not installed or not found in your system's PATH.
|
||||
Alternative Solution: You can still use WhisperLiveKit without FFmpeg by adding the --pcm-input parameter. Note that when using this option, audio will not be compressed between the frontend and backend, which may result in higher bandwidth usage.
|
||||
|
||||
If you want to install FFmpeg:
|
||||
|
||||
# Ubuntu/Debian:
|
||||
sudo apt update && sudo apt install ffmpeg
|
||||
|
||||
# macOS (using Homebrew):
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows:
|
||||
# 1. Download the latest static build from https://ffmpeg.org/download.html
|
||||
# 2. Extract the archive (e.g., to C:\\FFmpeg).
|
||||
# 3. Add the 'bin' directory (e.g., C:\\FFmpeg\\bin) to your system's PATH environment variable.
|
||||
|
||||
After installation, please restart the application.
|
||||
{'='*50}
|
||||
"""
|
||||
|
||||
class FFmpegState(Enum):
|
||||
STOPPED = "stopped"
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
RESTARTING = "restarting"
|
||||
FAILED = "failed"
|
||||
|
||||
class FFmpegManager:
|
||||
def __init__(self, sample_rate: int = 16000, channels: int = 1):
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
|
||||
self.process: Optional[asyncio.subprocess.Process] = None
|
||||
self._stderr_task: Optional[asyncio.Task] = None
|
||||
|
||||
self.on_error_callback: Optional[Callable[[str], None]] = None
|
||||
|
||||
self.state = FFmpegState.STOPPED
|
||||
self._state_lock = asyncio.Lock()
|
||||
|
||||
async def start(self) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.STOPPED:
|
||||
logger.warning(f"FFmpeg already running in state: {self.state}")
|
||||
return False
|
||||
self.state = FFmpegState.STARTING
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-hide_banner",
|
||||
"-loglevel", "error",
|
||||
"-i", "pipe:0",
|
||||
"-f", "s16le",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ac", str(self.channels),
|
||||
"-ar", str(self.sample_rate),
|
||||
"pipe:1"
|
||||
]
|
||||
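            # The command decodes whatever container/codec arrives on stdin ("pipe:0") and
            # writes raw signed 16-bit little-endian PCM ("-f s16le", "-acodec pcm_s16le"),
            # at `channels` channels and `sample_rate` Hz, to stdout ("pipe:1").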
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
self._stderr_task = asyncio.create_task(self._drain_stderr())
|
||||
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.RUNNING
|
||||
|
||||
logger.info("FFmpeg started.")
|
||||
return True
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(ERROR_INSTALL_INSTRUCTIONS)
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("ffmpeg_not_found")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting FFmpeg: {e}")
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("start_failed")
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
async with self._state_lock:
|
||||
if self.state == FFmpegState.STOPPED:
|
||||
return
|
||||
self.state = FFmpegState.STOPPED
|
||||
|
||||
if self.process:
|
||||
if self.process.stdin and not self.process.stdin.is_closing():
|
||||
self.process.stdin.close()
|
||||
await self.process.stdin.wait_closed()
|
||||
await self.process.wait()
|
||||
self.process = None
|
||||
|
||||
if self._stderr_task:
|
||||
self._stderr_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._stderr_task
|
||||
|
||||
logger.info("FFmpeg stopped.")
|
||||
|
||||
async def write_data(self, data: bytes) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.RUNNING:
|
||||
logger.warning(f"Cannot write, FFmpeg state: {self.state}")
|
||||
return False
|
||||
|
||||
try:
|
||||
self.process.stdin.write(data)
|
||||
await self.process.stdin.drain()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to FFmpeg: {e}")
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("write_error")
|
||||
return False
|
||||
|
||||
async def read_data(self, size: int) -> Optional[bytes]:
|
||||
async with self._state_lock:
|
||||
if self.state != FFmpegState.RUNNING:
|
||||
logger.warning(f"Cannot read, FFmpeg state: {self.state}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await asyncio.wait_for(
|
||||
self.process.stdout.read(size),
|
||||
timeout=20.0
|
||||
)
|
||||
return data
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("FFmpeg read timeout.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from FFmpeg: {e}")
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("read_error")
|
||||
return None
|
||||
|
||||
async def get_state(self) -> FFmpegState:
|
||||
async with self._state_lock:
|
||||
return self.state
|
||||
|
||||
async def restart(self) -> bool:
|
||||
async with self._state_lock:
|
||||
if self.state == FFmpegState.RESTARTING:
|
||||
logger.warning("Restart already in progress.")
|
||||
return False
|
||||
self.state = FFmpegState.RESTARTING
|
||||
|
||||
logger.info("Restarting FFmpeg...")
|
||||
|
||||
try:
|
||||
await self.stop()
|
||||
await asyncio.sleep(1) # short delay before restarting
|
||||
return await self.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during FFmpeg restart: {e}")
|
||||
async with self._state_lock:
|
||||
self.state = FFmpegState.FAILED
|
||||
if self.on_error_callback:
|
||||
await self.on_error_callback("restart_failed")
|
||||
return False
|
||||
|
||||
async def _drain_stderr(self):
|
||||
try:
|
||||
while True:
|
||||
if not self.process or not self.process.stderr:
|
||||
break
|
||||
line = await self.process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
logger.debug(f"FFmpeg stderr: {line.decode(errors='ignore').strip()}")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("FFmpeg stderr drain task cancelled.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error draining FFmpeg stderr: {e}")
|
||||
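A minimal usage sketch for FFmpegManager, assuming it is driven from an asyncio application; the chunk source and read size below are illustrative, not part of the module:

    import asyncio
    from whisperlivekit.ffmpeg_manager import FFmpegManager

    async def decode_stream(webm_chunks):
        manager = FFmpegManager(sample_rate=16000, channels=1)
        if not await manager.start():
            return  # FFmpeg missing or failed to start; the error is already logged
        for chunk in webm_chunks:                  # compressed audio bytes from the client
            await manager.write_data(chunk)
            pcm = await manager.read_data(4096)    # raw s16le PCM bytes, or None on timeout
            if pcm:
                pass                               # hand the PCM off to the ASR pipeline here
        await manager.stop()

    # asyncio.run(decode_stream(chunks_from_somewhere))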
whisperlivekit/local_agreement/__init__.py  (new file, empty)
whisperlivekit/local_agreement/backends.py  (new file, 299 lines)
@@ -0,0 +1,299 @@
|
||||
import sys
|
||||
import logging
|
||||
import io
|
||||
import soundfile as sf
|
||||
import math
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.whisper.transcribe import transcribe as whisper_transcribe
|
||||
logger = logging.getLogger(__name__)
|
||||
class ASRBase:
|
||||
sep = " " # join transcribe words with this character (" " for whisper_timestamped,
|
||||
# "" for faster-whisper because it emits the spaces when needed)
|
||||
|
||||
def __init__(self, lan, model_size=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
if lan == "auto":
|
||||
self.original_language = None
|
||||
else:
|
||||
self.original_language = lan
|
||||
self.model = self.load_model(model_size, cache_dir, model_dir)
|
||||
|
||||
    # NOTE: the two methods below operate on ASRToken attributes (start, end, text) and are
    # kept on ASRBase only for backward compatibility; prefer ASRToken.with_offset / repr(token).
    def with_offset(self, offset: float) -> ASRToken:
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
|
||||
|
||||
def load_model(self, model_size, cache_dir, model_dir):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
def use_vad(self):
|
||||
raise NotImplementedError("must be implemented in the child class")
|
||||
|
||||
|
||||
class WhisperASR(ASRBase):
|
||||
"""Uses WhisperLiveKit's built-in Whisper implementation."""
|
||||
sep = " "
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from whisperlivekit.whisper import load_model as load_model
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
if resolved_path.is_dir():
|
||||
pytorch_path, _, _ = model_path_and_type(resolved_path)
|
||||
if pytorch_path is None:
|
||||
raise FileNotFoundError(
|
||||
f"No supported PyTorch checkpoint found under {resolved_path}"
|
||||
)
|
||||
resolved_path = pytorch_path
|
||||
logger.debug(f"Loading Whisper model from custom path {resolved_path}")
|
||||
return load_model(str(resolved_path))
|
||||
|
||||
if model_size is None:
|
||||
raise ValueError("Either model_size or model_dir must be set for WhisperASR")
|
||||
|
||||
return load_model(model_size, download_root=cache_dir)
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
options = dict(self.transcribe_kargs)
|
||||
options.pop("vad", None)
|
||||
options.pop("vad_filter", None)
|
||||
language = self.original_language if self.original_language else None
|
||||
|
||||
result = whisper_transcribe(
|
||||
self.model,
|
||||
audio,
|
||||
language=language,
|
||||
initial_prompt=init_prompt,
|
||||
condition_on_previous_text=True,
|
||||
word_timestamps=True,
|
||||
**options,
|
||||
)
|
||||
return result
|
||||
|
||||
def ts_words(self, r) -> List[ASRToken]:
|
||||
"""
|
||||
Converts the Whisper result to a list of ASRToken objects.
|
||||
"""
|
||||
tokens = []
|
||||
for segment in r["segments"]:
|
||||
for word in segment["words"]:
|
||||
token = ASRToken(
|
||||
word["start"],
|
||||
word["end"],
|
||||
word["word"],
|
||||
probability=word.get("probability"),
|
||||
)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [segment["end"] for segment in res["segments"]]
|
||||
|
||||
def use_vad(self):
|
||||
logger.warning("VAD is not currently supported for WhisperASR backend and will be ignored.")
|
||||
|
||||
class FasterWhisperASR(ASRBase):
|
||||
"""Uses faster-whisper as the backend."""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading faster-whisper model from {resolved_path}. "
|
||||
f"model_size and cache_dir parameters are not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
device = "auto" # Allow CTranslate2 to decide available device
|
||||
compute_type = "auto" # Allow CTranslate2 to decide faster compute type
|
||||
|
||||
|
||||
model = WhisperModel(
|
||||
model_size_or_path,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
download_root=cache_dir,
|
||||
)
|
||||
return model
|
||||
|
||||
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
|
||||
segments, info = self.model.transcribe(
|
||||
audio,
|
||||
language=self.original_language,
|
||||
initial_prompt=init_prompt,
|
||||
beam_size=5,
|
||||
word_timestamps=True,
|
||||
condition_on_previous_text=True,
|
||||
**self.transcribe_kargs,
|
||||
)
|
||||
return list(segments)
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
tokens = []
|
||||
for segment in segments:
|
||||
if segment.no_speech_prob > 0.9:
|
||||
continue
|
||||
for word in segment.words:
|
||||
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, segments) -> List[float]:
|
||||
return [segment.end for segment in segments]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
class MLXWhisper(ASRBase):
|
||||
"""
|
||||
Uses MLX Whisper optimized for Apple Silicon.
|
||||
"""
|
||||
sep = ""
|
||||
|
||||
def load_model(self, model_size=None, cache_dir=None, model_dir=None):
|
||||
from mlx_whisper.transcribe import ModelHolder, transcribe
|
||||
import mlx.core as mx
|
||||
|
||||
if model_dir is not None:
|
||||
resolved_path = resolve_model_path(model_dir)
|
||||
logger.debug(f"Loading MLX Whisper model from {resolved_path}. model_size parameter is not used.")
|
||||
model_size_or_path = str(resolved_path)
|
||||
elif model_size is not None:
|
||||
model_size_or_path = self.translate_model_name(model_size)
|
||||
logger.debug(f"Loading whisper model {model_size}. You use mlx whisper, so {model_size_or_path} will be used.")
|
||||
else:
|
||||
raise ValueError("Either model_size or model_dir must be set")
|
||||
|
||||
self.model_size_or_path = model_size_or_path
|
||||
dtype = mx.float16
|
||||
ModelHolder.get_model(model_size_or_path, dtype)
|
||||
return transcribe
|
||||
|
||||
def translate_model_name(self, model_name):
|
||||
model_mapping = {
|
||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||
"base": "mlx-community/whisper-base-mlx",
|
||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||
"small": "mlx-community/whisper-small-mlx",
|
||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||
"medium": "mlx-community/whisper-medium-mlx",
|
||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||
"large": "mlx-community/whisper-large-mlx",
|
||||
}
|
||||
mlx_model_path = model_mapping.get(model_name)
|
||||
if mlx_model_path:
|
||||
return mlx_model_path
|
||||
else:
|
||||
raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
|
||||
|
||||
def transcribe(self, audio, init_prompt=""):
|
||||
if self.transcribe_kargs:
|
||||
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
|
||||
segments = self.model(
|
||||
audio,
|
||||
language=self.original_language,
|
||||
initial_prompt=init_prompt,
|
||||
word_timestamps=True,
|
||||
condition_on_previous_text=True,
|
||||
path_or_hf_repo=self.model_size_or_path,
|
||||
)
|
||||
return segments.get("segments", [])
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
tokens = []
|
||||
for segment in segments:
|
||||
if segment.get("no_speech_prob", 0) > 0.9:
|
||||
continue
|
||||
for word in segment.get("words", []):
|
||||
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
|
||||
tokens.append(token)
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [s["end"] for s in res]
|
||||
|
||||
def use_vad(self):
|
||||
self.transcribe_kargs["vad_filter"] = True
|
||||
|
||||
class OpenaiApiASR(ASRBase):
|
||||
"""Uses OpenAI's Whisper API for transcription."""
|
||||
def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
|
||||
self.logfile = logfile
|
||||
self.modelname = "whisper-1"
|
||||
self.original_language = None if lan == "auto" else lan
|
||||
self.response_format = "verbose_json"
|
||||
self.temperature = temperature
|
||||
self.load_model()
|
||||
self.use_vad_opt = False
|
||||
self.direct_english_translation = False
|
||||
|
||||
def load_model(self, *args, **kwargs):
|
||||
from openai import OpenAI
|
||||
self.client = OpenAI()
|
||||
self.transcribed_seconds = 0
|
||||
|
||||
def ts_words(self, segments) -> List[ASRToken]:
|
||||
"""
|
||||
Converts OpenAI API response words into ASRToken objects while
|
||||
optionally skipping words that fall into no-speech segments.
|
||||
"""
|
||||
no_speech_segments = []
|
||||
if self.use_vad_opt:
|
||||
for segment in segments.segments:
|
||||
if segment.no_speech_prob > 0.8:
|
||||
no_speech_segments.append((segment.start, segment.end))
|
||||
tokens = []
|
||||
for word in segments.words:
|
||||
start = word.start
|
||||
end = word.end
|
||||
if any(s[0] <= start <= s[1] for s in no_speech_segments):
|
||||
continue
|
||||
tokens.append(ASRToken(start, end, word.word))
|
||||
return tokens
|
||||
|
||||
def segments_end_ts(self, res) -> List[float]:
|
||||
return [s.end for s in res.words]
|
||||
|
||||
def transcribe(self, audio_data, prompt=None, *args, **kwargs):
|
||||
buffer = io.BytesIO()
|
||||
buffer.name = "temp.wav"
|
||||
sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
|
||||
buffer.seek(0)
|
||||
self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
|
||||
params = {
|
||||
"model": self.modelname,
|
||||
"file": buffer,
|
||||
"response_format": self.response_format,
|
||||
"temperature": self.temperature,
|
||||
"timestamp_granularities": ["word", "segment"],
|
||||
}
|
||||
if not self.direct_english_translation and self.original_language:
|
||||
params["language"] = self.original_language
|
||||
if prompt:
|
||||
params["prompt"] = prompt
|
||||
        proc = self.client.audio.translations if self.direct_english_translation else self.client.audio.transcriptions
|
||||
transcript = proc.create(**params)
|
||||
logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
|
||||
return transcript
|
||||
|
||||
def use_vad(self):
|
||||
self.use_vad_opt = True
|
||||
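A short usage sketch for the backends above; the model name and audio file are illustrative, and any backend exposing transcribe/ts_words/segments_end_ts is driven the same way:

    import librosa
    from whisperlivekit.local_agreement.backends import FasterWhisperASR

    audio, _ = librosa.load("sample.wav", sr=16000)        # 16 kHz mono float32
    asr = FasterWhisperASR(lan="en", model_size="base")
    segments = asr.transcribe(audio, init_prompt="")
    tokens = asr.ts_words(segments)                         # List[ASRToken] with word timestamps
    print(asr.sep.join(token.text for token in tokens))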
whisperlivekit/local_agreement/online_asr.py  (new file, 421 lines)
@@ -0,0 +1,421 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
from whisperlivekit.timed_objects import ASRToken, Sentence, Transcript
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class HypothesisBuffer:
|
||||
"""
|
||||
Buffer to store and process ASR hypothesis tokens.
|
||||
|
||||
It holds:
|
||||
- committed_in_buffer: tokens that have been confirmed (committed)
|
||||
- buffer: the last hypothesis that is not yet committed
|
||||
- new: new tokens coming from the recognizer
|
||||
"""
|
||||
def __init__(self, logfile=sys.stderr, confidence_validation=False):
|
||||
self.confidence_validation = confidence_validation
|
||||
self.committed_in_buffer: List[ASRToken] = []
|
||||
self.buffer: List[ASRToken] = []
|
||||
self.new: List[ASRToken] = []
|
||||
self.last_committed_time = 0.0
|
||||
self.last_committed_word: Optional[str] = None
|
||||
self.logfile = logfile
|
||||
|
||||
def insert(self, new_tokens: List[ASRToken], offset: float):
|
||||
"""
|
||||
Insert new tokens (after applying a time offset) and compare them with the
|
||||
already committed tokens. Only tokens that extend the committed hypothesis
|
||||
are added.
|
||||
"""
|
||||
# Apply the offset to each token.
|
||||
new_tokens = [token.with_offset(offset) for token in new_tokens]
|
||||
# Only keep tokens that are roughly “new”
|
||||
self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]
|
||||
|
||||
if self.new:
|
||||
first_token = self.new[0]
|
||||
if abs(first_token.start - self.last_committed_time) < 1:
|
||||
if self.committed_in_buffer:
|
||||
committed_len = len(self.committed_in_buffer)
|
||||
new_len = len(self.new)
|
||||
# Try to match 1 to 5 consecutive tokens
|
||||
max_ngram = min(min(committed_len, new_len), 5)
|
||||
for i in range(1, max_ngram + 1):
|
||||
committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
|
||||
new_ngram = " ".join(token.text for token in self.new[:i])
|
||||
if committed_ngram == new_ngram:
|
||||
removed = []
|
||||
for _ in range(i):
|
||||
removed_token = self.new.pop(0)
|
||||
removed.append(repr(removed_token))
|
||||
logger.debug(f"Removing last {i} words: {' '.join(removed)}")
|
||||
break
|
||||
|
||||
def flush(self) -> List[ASRToken]:
|
||||
"""
|
||||
Returns the committed chunk, defined as the longest common prefix
|
||||
between the previous hypothesis and the new tokens.
|
||||
"""
|
||||
committed: List[ASRToken] = []
|
||||
while self.new:
|
||||
current_new = self.new[0]
|
||||
if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
|
||||
committed.append(current_new)
|
||||
self.last_committed_word = current_new.text
|
||||
self.last_committed_time = current_new.end
|
||||
self.new.pop(0)
|
||||
                if self.buffer:
                    self.buffer.pop(0)
|
||||
elif not self.buffer:
|
||||
break
|
||||
elif current_new.text == self.buffer[0].text:
|
||||
committed.append(current_new)
|
||||
self.last_committed_word = current_new.text
|
||||
self.last_committed_time = current_new.end
|
||||
self.buffer.pop(0)
|
||||
self.new.pop(0)
|
||||
else:
|
||||
break
|
||||
self.buffer = self.new
|
||||
self.new = []
|
||||
self.committed_in_buffer.extend(committed)
|
||||
return committed
|
||||
|
||||
def pop_committed(self, time: float):
|
||||
"""
|
||||
Remove tokens (from the beginning) that have ended before `time`.
|
||||
"""
|
||||
while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
|
||||
self.committed_in_buffer.pop(0)
|
||||
|
||||
|
||||
|
||||
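# Illustrative note (not part of the original code): HypothesisBuffer implements the
# LocalAgreement policy, committing a token only once two consecutive hypotheses agree on it.
# For example, inserting tokens for "hello wor" and then, on the next iteration, tokens for
# "hello world how" makes flush() return the "hello" token (the common prefix), while
# "world" stays in the uncommitted buffer until a later hypothesis confirms it.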
class OnlineASRProcessor:
|
||||
"""
|
||||
Processes incoming audio in a streaming fashion, calling the ASR system
|
||||
periodically, and uses a hypothesis buffer to commit and trim recognized text.
|
||||
|
||||
The processor supports two types of buffer trimming:
|
||||
- "sentence": trims at sentence boundaries (using a sentence tokenizer)
|
||||
- "segment": trims at fixed segment durations.
|
||||
"""
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
"""
|
||||
asr: An ASR system object (for example, a WhisperASR instance) that
|
||||
provides a `transcribe` method, a `ts_words` method (to extract tokens),
|
||||
a `segments_end_ts` method, and a separator attribute `sep`.
|
||||
tokenize_method: A function that receives text and returns a list of sentence strings.
|
||||
buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
|
||||
"""
|
||||
self.asr = asr
|
||||
self.tokenize = asr.tokenizer
|
||||
self.logfile = logfile
|
||||
self.confidence_validation = asr.confidence_validation
|
||||
self.global_time_offset = 0.0
|
||||
self.init()
|
||||
|
||||
self.buffer_trimming_way = asr.buffer_trimming
|
||||
self.buffer_trimming_sec = asr.buffer_trimming_sec
|
||||
|
||||
if self.buffer_trimming_way not in ["sentence", "segment"]:
|
||||
raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
|
||||
if self.buffer_trimming_sec <= 0:
|
||||
raise ValueError("buffer_trimming_sec must be positive")
|
||||
elif self.buffer_trimming_sec > 30:
|
||||
logger.warning(
|
||||
f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
|
||||
)
|
||||
|
||||
def init(self, offset: Optional[float] = None):
|
||||
"""Initialize or reset the processing buffers."""
|
||||
self.audio_buffer = np.array([], dtype=np.float32)
|
||||
self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
|
||||
self.buffer_time_offset = offset if offset is not None else 0.0
|
||||
self.transcript_buffer.last_committed_time = self.buffer_time_offset
|
||||
self.committed: List[ASRToken] = []
|
||||
self.time_of_last_asr_output = 0.0
|
||||
|
||||
def get_audio_buffer_end_time(self) -> float:
|
||||
"""Returns the absolute end time of the current audio_buffer."""
|
||||
return self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: Optional[float] = None):
|
||||
"""Append an audio chunk (a numpy array) to the current audio buffer."""
|
||||
self.audio_buffer = np.append(self.audio_buffer, audio)
|
||||
|
||||
def start_silence(self):
|
||||
if self.audio_buffer.size == 0:
|
||||
return [], self.get_audio_buffer_end_time()
|
||||
return self.process_iter()
|
||||
|
||||
def end_silence(self, silence_duration: Optional[float], offset: float):
|
||||
if not silence_duration or silence_duration <= 0:
|
||||
return
|
||||
|
||||
long_silence = silence_duration >= 5
|
||||
if not long_silence:
|
||||
gap_samples = int(self.SAMPLING_RATE * silence_duration)
|
||||
if gap_samples > 0:
|
||||
gap_silence = np.zeros(gap_samples, dtype=np.float32)
|
||||
self.insert_audio_chunk(gap_silence)
|
||||
else:
|
||||
self.init(offset=silence_duration + offset)
|
||||
|
||||
self.global_time_offset += silence_duration
|
||||
|
||||
def insert_silence(self, silence_duration, offset):
|
||||
"""
|
||||
Backwards compatibility shim for legacy callers that still use insert_silence.
|
||||
"""
|
||||
self.end_silence(silence_duration, offset)
|
||||
|
||||
def prompt(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns a tuple: (prompt, context), where:
|
||||
- prompt is a 200-character suffix of committed text that falls
|
||||
outside the current audio buffer.
|
||||
- context is the committed text within the current audio buffer.
|
||||
"""
|
||||
k = len(self.committed)
|
||||
while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
|
||||
k -= 1
|
||||
|
||||
prompt_tokens = self.committed[:k]
|
||||
prompt_words = [token.text for token in prompt_tokens]
|
||||
prompt_list = []
|
||||
length_count = 0
|
||||
# Use the last words until reaching 200 characters.
|
||||
while prompt_words and length_count < 200:
|
||||
word = prompt_words.pop(-1)
|
||||
length_count += len(word) + 1
|
||||
prompt_list.append(word)
|
||||
non_prompt_tokens = self.committed[k:]
|
||||
context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
|
||||
return self.asr.sep.join(prompt_list[::-1]), context_text
|
||||
|
||||
def get_buffer(self):
|
||||
"""
|
||||
Get the unvalidated buffer in string format.
|
||||
"""
|
||||
return self.concatenate_tokens(self.transcript_buffer.buffer)
|
||||
|
||||
|
||||
def process_iter(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Processes the current audio buffer.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
current_audio_processed_upto = self.get_audio_buffer_end_time()
|
||||
prompt_text, _ = self.prompt()
|
||||
logger.debug(
|
||||
f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
|
||||
)
|
||||
res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
|
||||
tokens = self.asr.ts_words(res)
|
||||
self.transcript_buffer.insert(tokens, self.buffer_time_offset)
|
||||
committed_tokens = self.transcript_buffer.flush()
|
||||
self.committed.extend(committed_tokens)
|
||||
|
||||
if committed_tokens:
|
||||
self.time_of_last_asr_output = self.committed[-1].end
|
||||
|
||||
completed = self.concatenate_tokens(committed_tokens)
|
||||
logger.debug(f">>>> COMPLETE NOW: {completed.text}")
|
||||
incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
|
||||
logger.debug(f"INCOMPLETE: {incomp.text}")
|
||||
|
||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
if not committed_tokens and buffer_duration > self.buffer_trimming_sec:
|
||||
time_since_last_output = self.get_audio_buffer_end_time() - self.time_of_last_asr_output
|
||||
if time_since_last_output > self.buffer_trimming_sec:
|
||||
logger.warning(
|
||||
f"No ASR output for {time_since_last_output:.2f}s. "
|
||||
f"Resetting buffer to prevent freezing."
|
||||
)
|
||||
self.init(offset=self.get_audio_buffer_end_time())
|
||||
return [], current_audio_processed_upto
|
||||
|
||||
if committed_tokens and self.buffer_trimming_way == "sentence":
|
||||
if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
|
||||
self.chunk_completed_sentence()
|
||||
|
||||
s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
|
||||
if len(self.audio_buffer) / self.SAMPLING_RATE > s:
|
||||
self.chunk_completed_segment(res)
|
||||
logger.debug("Chunking segment")
|
||||
logger.debug(
|
||||
f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
|
||||
)
|
||||
        if self.global_time_offset:
            # ASRToken.with_offset returns a new token, so rebuild the list instead of
            # reassigning the loop variable (which would silently discard the shift).
            committed_tokens = [token.with_offset(self.global_time_offset) for token in committed_tokens]
return committed_tokens, current_audio_processed_upto
|
||||
|
||||
def chunk_completed_sentence(self):
|
||||
"""
|
||||
If the committed tokens form at least two sentences, chunk the audio
|
||||
buffer at the end time of the penultimate sentence.
|
||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||
"""
|
||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
if not self.committed:
|
||||
if buffer_duration > self.buffer_trimming_sec:
|
||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
return
|
||||
|
||||
logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
|
||||
sentences = self.words_to_sentences(self.committed)
|
||||
for sentence in sentences:
|
||||
logger.debug(f"\tSentence: {sentence.text}")
|
||||
|
||||
chunk_done = False
|
||||
if len(sentences) >= 2:
|
||||
while len(sentences) > 2:
|
||||
sentences.pop(0)
|
||||
chunk_time = sentences[-2].end
|
||||
logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
chunk_done = True
|
||||
|
||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||
last_committed_time = self.committed[-1].end
|
||||
logger.debug(f"--- Not enough sentences, chunking at last committed time {last_committed_time:.2f}")
|
||||
self.chunk_at(last_committed_time)
|
||||
|
||||
def chunk_completed_segment(self, res):
|
||||
"""
|
||||
Chunk the audio buffer based on segment-end timestamps reported by the ASR.
|
||||
Also ensures chunking happens if audio buffer exceeds a time limit.
|
||||
"""
|
||||
buffer_duration = len(self.audio_buffer) / self.SAMPLING_RATE
|
||||
if not self.committed:
|
||||
if buffer_duration > self.buffer_trimming_sec:
|
||||
chunk_time = self.buffer_time_offset + (buffer_duration / 2)
|
||||
logger.debug(f"--- No speech detected, forced chunking at {chunk_time:.2f}")
|
||||
self.chunk_at(chunk_time)
|
||||
return
|
||||
|
||||
logger.debug("Processing committed tokens for segmenting")
|
||||
ends = self.asr.segments_end_ts(res)
|
||||
last_committed_time = self.committed[-1].end
|
||||
chunk_done = False
|
||||
if len(ends) > 1:
|
||||
logger.debug("Multiple segments available for chunking")
|
||||
e = ends[-2] + self.buffer_time_offset
|
||||
while len(ends) > 2 and e > last_committed_time:
|
||||
ends.pop(-1)
|
||||
e = ends[-2] + self.buffer_time_offset
|
||||
if e <= last_committed_time:
|
||||
logger.debug(f"--- Segment chunked at {e:.2f}")
|
||||
self.chunk_at(e)
|
||||
chunk_done = True
|
||||
else:
|
||||
logger.debug("--- Last segment not within committed area")
|
||||
else:
|
||||
logger.debug("--- Not enough segments to chunk")
|
||||
|
||||
if not chunk_done and buffer_duration > self.buffer_trimming_sec:
|
||||
logger.debug(f"--- Buffer too large, chunking at last committed time {last_committed_time:.2f}")
|
||||
self.chunk_at(last_committed_time)
|
||||
|
||||
logger.debug("Segment chunking complete")
|
||||
|
||||
def chunk_at(self, time: float):
|
||||
"""
|
||||
Trim both the hypothesis and audio buffer at the given time.
|
||||
"""
|
||||
logger.debug(f"Chunking at {time:.2f}s")
|
||||
logger.debug(
|
||||
f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
|
||||
)
|
||||
self.transcript_buffer.pop_committed(time)
|
||||
cut_seconds = time - self.buffer_time_offset
|
||||
self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
|
||||
self.buffer_time_offset = time
|
||||
logger.debug(
|
||||
f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
|
||||
)
|
||||
|
||||
def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
|
||||
"""
|
||||
Converts a list of tokens to a list of Sentence objects using the provided
|
||||
sentence tokenizer.
|
||||
"""
|
||||
if not tokens:
|
||||
return []
|
||||
|
||||
full_text = " ".join(token.text for token in tokens)
|
||||
|
||||
if self.tokenize:
|
||||
try:
|
||||
sentence_texts = self.tokenize(full_text)
|
||||
except Exception as e:
|
||||
# Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
|
||||
try:
|
||||
sentence_texts = self.tokenize([full_text])
|
||||
except Exception as e2:
|
||||
raise ValueError("Tokenization failed") from e2
|
||||
else:
|
||||
sentence_texts = [full_text]
|
||||
|
||||
sentences: List[Sentence] = []
|
||||
token_index = 0
|
||||
for sent_text in sentence_texts:
|
||||
sent_text = sent_text.strip()
|
||||
if not sent_text:
|
||||
continue
|
||||
sent_tokens = []
|
||||
accumulated = ""
|
||||
# Accumulate tokens until roughly matching the length of the sentence text.
|
||||
while token_index < len(tokens) and len(accumulated) < len(sent_text):
|
||||
token = tokens[token_index]
|
||||
accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
|
||||
sent_tokens.append(token)
|
||||
token_index += 1
|
||||
if sent_tokens:
|
||||
sentence = Sentence(
|
||||
start=sent_tokens[0].start,
|
||||
end=sent_tokens[-1].end,
|
||||
text=" ".join(t.text for t in sent_tokens),
|
||||
)
|
||||
sentences.append(sentence)
|
||||
return sentences
|
||||
|
||||
def finish(self) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Flush the remaining transcript when processing ends.
|
||||
Returns a tuple: (list of remaining ASRToken objects, float representing the final audio processed up to time).
|
||||
"""
|
||||
remaining_tokens = self.transcript_buffer.buffer
|
||||
logger.debug(f"Final non-committed tokens: {remaining_tokens}")
|
||||
final_processed_upto = self.buffer_time_offset + (len(self.audio_buffer) / self.SAMPLING_RATE)
|
||||
self.buffer_time_offset = final_processed_upto
|
||||
return remaining_tokens, final_processed_upto
|
||||
|
||||
def concatenate_tokens(
|
||||
self,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> Transcript:
|
||||
sep = sep if sep is not None else self.asr.sep
|
||||
text = sep.join(token.text for token in tokens)
|
||||
probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return Transcript(start, end, text, probability=probability)
|
||||
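A minimal streaming-loop sketch for OnlineASRProcessor; the audio file and chunking are illustrative, and the backend is built with backend_factory from whisper_online.py (shown below) so it carries the attributes the processor expects:

    import librosa
    from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
    from whisperlivekit.local_agreement.whisper_online import backend_factory

    asr = backend_factory(
        backend="auto", lan="en", model_size="base",
        model_cache_dir=None, model_dir=None, model_path=None,
        direct_english_translation=False,
        buffer_trimming="segment", buffer_trimming_sec=15,
        confidence_validation=False,
    )
    online = OnlineASRProcessor(asr)

    audio, _ = librosa.load("sample.wav", sr=OnlineASRProcessor.SAMPLING_RATE)
    step = OnlineASRProcessor.SAMPLING_RATE               # feed one second per iteration
    for i in range(0, len(audio), step):
        online.insert_audio_chunk(audio[i:i + step])
        committed, _ = online.process_iter()
        if committed:
            print(online.concatenate_tokens(committed).text)
    remaining, _ = online.finish()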
whisperlivekit/local_agreement/whisper_online.py  (new file, 199 lines)
@@ -0,0 +1,199 @@
|
||||
#!/usr/bin/env python3
|
||||
import sys
|
||||
import numpy as np
|
||||
import librosa
|
||||
from functools import lru_cache
|
||||
import time
|
||||
import logging
|
||||
import platform
|
||||
from .backends import FasterWhisperASR, MLXWhisper, WhisperASR, OpenaiApiASR
|
||||
from whisperlivekit.warmup import warmup_asr
|
||||
from whisperlivekit.model_paths import resolve_model_path, model_path_and_type
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
|
||||
","
|
||||
)
|
||||
|
||||
|
||||
def create_tokenizer(lan):
|
||||
"""returns an object that has split function that works like the one of MosesTokenizer"""
|
||||
|
||||
assert (
|
||||
lan in WHISPER_LANG_CODES
|
||||
), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
|
||||
|
||||
if lan == "uk":
|
||||
import tokenize_uk
|
||||
|
||||
class UkrainianTokenizer:
|
||||
def split(self, text):
|
||||
return tokenize_uk.tokenize_sents(text)
|
||||
|
||||
return UkrainianTokenizer()
|
||||
|
||||
# supported by fast-mosestokenizer
|
||||
if (
|
||||
lan
|
||||
in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
|
||||
):
|
||||
from mosestokenizer import MosesSentenceSplitter
|
||||
|
||||
return MosesSentenceSplitter(lan)
|
||||
|
||||
# the following languages are in Whisper, but not in wtpsplit:
|
||||
if (
|
||||
lan
|
||||
in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
|
||||
):
|
||||
logger.debug(
|
||||
f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
|
||||
)
|
||||
lan = None
|
||||
|
||||
from wtpsplit import WtP
|
||||
|
||||
# downloads the model from huggingface on the first use
|
||||
wtp = WtP("wtp-canine-s-12l-no-adapters")
|
||||
|
||||
class WtPtok:
|
||||
def split(self, sent):
|
||||
return wtp.split(sent, lang_code=lan)
|
||||
|
||||
return WtPtok()
|
||||
|
||||
|
||||
def backend_factory(
|
||||
backend,
|
||||
lan,
|
||||
model_size,
|
||||
model_cache_dir,
|
||||
model_dir,
|
||||
model_path,
|
||||
direct_english_translation,
|
||||
buffer_trimming,
|
||||
buffer_trimming_sec,
|
||||
confidence_validation,
|
||||
warmup_file=None,
|
||||
min_chunk_size=None,
|
||||
):
|
||||
backend_choice = backend
|
||||
custom_reference = model_path or model_dir
|
||||
resolved_root = None
|
||||
pytorch_checkpoint = None
|
||||
has_mlx_weights = False
|
||||
has_fw_weights = False
|
||||
|
||||
if custom_reference:
|
||||
resolved_root = resolve_model_path(custom_reference)
|
||||
if resolved_root.is_dir():
|
||||
pytorch_checkpoint, has_mlx_weights, has_fw_weights = model_path_and_type(resolved_root)
|
||||
else:
|
||||
pytorch_checkpoint = resolved_root
|
||||
|
||||
if backend_choice == "openai-api":
|
||||
logger.debug("Using OpenAI API.")
|
||||
asr = OpenaiApiASR(lan=lan)
|
||||
else:
|
||||
backend_choice = _normalize_backend_choice(
|
||||
backend_choice,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
)
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
asr_cls = FasterWhisperASR
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("Faster-Whisper backend expects a directory with CTranslate2 weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
elif backend_choice == "mlx-whisper":
|
||||
asr_cls = MLXWhisper
|
||||
if resolved_root is not None and not resolved_root.is_dir():
|
||||
raise ValueError("MLX Whisper backend expects a directory containing MLX weights.")
|
||||
model_override = str(resolved_root) if resolved_root is not None else None
|
||||
else:
|
||||
asr_cls = WhisperASR
|
||||
model_override = str(pytorch_checkpoint) if pytorch_checkpoint is not None else None
|
||||
if custom_reference and model_override is None:
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint found under {resolved_root or custom_reference}"
|
||||
)
|
||||
|
||||
t = time.time()
|
||||
logger.info(f"Loading Whisper {model_size} model for language {lan} using backend {backend_choice}...")
|
||||
asr = asr_cls(
|
||||
model_size=model_size,
|
||||
lan=lan,
|
||||
cache_dir=model_cache_dir,
|
||||
model_dir=model_override,
|
||||
)
|
||||
e = time.time()
|
||||
logger.info(f"done. It took {round(e-t,2)} seconds.")
|
||||
|
||||
if direct_english_translation:
|
||||
tgt_language = "en" # Whisper translates into English
|
||||
else:
|
||||
tgt_language = lan # Whisper transcribes in this language
|
||||
|
||||
# Create the tokenizer
|
||||
if buffer_trimming == "sentence":
|
||||
tokenizer = create_tokenizer(tgt_language)
|
||||
else:
|
||||
tokenizer = None
|
||||
|
||||
warmup_asr(asr, warmup_file)
|
||||
|
||||
asr.confidence_validation = confidence_validation
|
||||
asr.tokenizer = tokenizer
|
||||
asr.buffer_trimming = buffer_trimming
|
||||
asr.buffer_trimming_sec = buffer_trimming_sec
|
||||
asr.backend_choice = backend_choice
|
||||
return asr
|
||||
|
||||
|
||||
def _normalize_backend_choice(
|
||||
preferred_backend,
|
||||
resolved_root,
|
||||
has_mlx_weights,
|
||||
has_fw_weights,
|
||||
):
|
||||
backend_choice = preferred_backend
|
||||
|
||||
if backend_choice == "auto":
|
||||
if mlx_backend_available(warn_on_missing=True) and (resolved_root is None or has_mlx_weights):
|
||||
return "mlx-whisper"
|
||||
if faster_backend_available(warn_on_missing=True) and (resolved_root is None or has_fw_weights):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
if backend_choice == "mlx-whisper":
|
||||
if not mlx_backend_available():
|
||||
raise RuntimeError("mlx-whisper backend requested but mlx-whisper is not installed.")
|
||||
if resolved_root is not None and not has_mlx_weights:
|
||||
raise FileNotFoundError(
|
||||
f"mlx-whisper backend requested but no MLX weights were found under {resolved_root}"
|
||||
)
|
||||
if platform.system() != "Darwin":
|
||||
logger.warning("mlx-whisper backend requested on a non-macOS system; this may fail.")
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "faster-whisper":
|
||||
if not faster_backend_available():
|
||||
raise RuntimeError("faster-whisper backend requested but faster-whisper is not installed.")
|
||||
if resolved_root is not None and not has_fw_weights:
|
||||
raise FileNotFoundError(
|
||||
f"faster-whisper backend requested but no Faster-Whisper weights were found under {resolved_root}"
|
||||
)
|
||||
return backend_choice
|
||||
|
||||
if backend_choice == "whisper":
|
||||
return backend_choice
|
||||
|
||||
raise ValueError(f"Unknown backend '{preferred_backend}' for LocalAgreement.")
|
||||
whisperlivekit/model_paths.py  (new file, 69 lines)
@@ -0,0 +1,69 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
def model_path_and_type(model_path: Union[str, Path]) -> Tuple[Optional[Path], bool, bool]:
|
||||
"""
|
||||
Inspect the provided path and determine which model formats are available.
|
||||
|
||||
Returns:
|
||||
pytorch_path: Path to a PyTorch checkpoint (if present).
|
||||
compatible_whisper_mlx: True if MLX weights exist in this folder.
|
||||
compatible_faster_whisper: True if Faster-Whisper (ctranslate2) weights exist.
|
||||
"""
|
||||
path = Path(model_path)
|
||||
|
||||
compatible_whisper_mlx = False
|
||||
compatible_faster_whisper = False
|
||||
pytorch_path: Optional[Path] = None
|
||||
|
||||
if path.is_file() and path.suffix.lower() in [".pt", ".safetensors", ".bin"]:
|
||||
pytorch_path = path
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
if path.is_dir():
|
||||
for file in path.iterdir():
|
||||
if not file.is_file():
|
||||
continue
|
||||
|
||||
filename = file.name.lower()
|
||||
suffix = file.suffix.lower()
|
||||
|
||||
if filename in {"weights.npz", "weights.safetensors"}:
|
||||
compatible_whisper_mlx = True
|
||||
elif filename in {"model.bin", "encoder.bin", "decoder.bin"}:
|
||||
compatible_faster_whisper = True
|
||||
elif suffix in {".pt", ".safetensors"}:
|
||||
pytorch_path = file
|
||||
elif filename == "pytorch_model.bin":
|
||||
pytorch_path = file
|
||||
|
||||
if pytorch_path is None:
|
||||
fallback = path / "pytorch_model.bin"
|
||||
if fallback.exists():
|
||||
pytorch_path = fallback
|
||||
|
||||
return pytorch_path, compatible_whisper_mlx, compatible_faster_whisper
|
||||
|
||||
|
||||
def resolve_model_path(model_path: Union[str, Path]) -> Path:
|
||||
"""
|
||||
Return a local path for the provided model reference.
|
||||
|
||||
If the path does not exist locally, it is treated as a Hugging Face repo id
|
||||
and downloaded via snapshot_download.
|
||||
"""
|
||||
path = Path(model_path).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as exc: # pragma: no cover - optional dependency guard
|
||||
raise FileNotFoundError(
|
||||
f"Model path '{model_path}' does not exist locally and huggingface_hub "
|
||||
"is not installed to download it."
|
||||
) from exc
|
||||
|
||||
downloaded_path = Path(snapshot_download(repo_id=str(model_path)))
|
||||
return downloaded_path
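A minimal usage sketch of the two helpers above (not part of the diff; the directory below is purely illustrative and assumed to already exist locally, otherwise resolve_model_path would treat it as a Hugging Face repo id):
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path

local_dir = resolve_model_path("~/models/whisper-small")          # returns the expanded local path
pytorch_path, has_mlx, has_fw = model_path_and_type(local_dir)    # inspect which weight formats are present
print(pytorch_path, has_mlx, has_fw)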
|
||||
332
whisperlivekit/parse_args.py
Normal file
@@ -0,0 +1,332 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser(description="Whisper FastAPI Online Server")
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="The host address to bind the server to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="The port number to bind the server to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--warmup-file",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="warmup_file",
|
||||
help="""
|
||||
The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
|
||||
If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
|
||||
If empty, no warmup is performed.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence-validation",
|
||||
action="store_true",
|
||||
help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable speaker diarization.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--punctuation-split",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use punctuation marks from transcription to improve speaker boundary detection. Requires both transcription and diarization to be enabled.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--segmentation-model",
|
||||
type=str,
|
||||
default="pyannote/segmentation-3.0",
|
||||
help="Hugging Face model ID for pyannote.audio segmentation model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default="pyannote/embedding",
|
||||
help="Hugging Face model ID for pyannote.audio embedding model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diarization-backend",
|
||||
type=str,
|
||||
default="sortformer",
|
||||
choices=["sortformer", "diart"],
|
||||
help="The diarization backend to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-transcription",
|
||||
action="store_true",
|
||||
help="Disable transcription to only see live diarization results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-punctuation-split",
|
||||
action="store_true",
|
||||
help="Disable the split parameter.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--min-chunk-size",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="base",
|
||||
dest='model_size',
|
||||
help="Name size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in model cache dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Overriding the default model cache dir where models downloaded from the hub are saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lan",
|
||||
"--language",
|
||||
type=str,
|
||||
default="auto",
|
||||
dest='lan',
|
||||
help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--direct-english-translation",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Whisper to directly translate to english.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--target-language",
|
||||
type=str,
|
||||
default="",
|
||||
dest="target_language",
|
||||
help="Target language for translation. Not functional yet.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--backend-policy",
|
||||
type=str,
|
||||
default="simulstreaming",
|
||||
choices=["1", "2", "simulstreaming", "localagreement"],
|
||||
help="Select the streaming policy: 1 or 'simulstreaming' for AlignAtt, 2 or 'localagreement' for LocalAgreement.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "mlx-whisper", "faster-whisper", "whisper", "openai-api"],
|
||||
help="Select the Whisper backend implementation (auto: prefer MLX on macOS, otherwise Faster-Whisper, else Whisper). Use 'openai-api' with --backend-policy localagreement to call OpenAI's API.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-vac",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable VAC = voice activity controller.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-vad",
|
||||
action="store_true",
|
||||
help="Disable VAD (voice activity detection).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--buffer_trimming",
|
||||
type=str,
|
||||
default="segment",
|
||||
choices=["sentence", "segment"],
|
||||
help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer_trimming_sec",
|
||||
type=float,
|
||||
default=15,
|
||||
help="Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
dest="log_level",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level",
|
||||
default="DEBUG",
|
||||
)
|
||||
parser.add_argument("--ssl-certfile", type=str, help="Path to the SSL certificate file.", default=None)
|
||||
parser.add_argument("--ssl-keyfile", type=str, help="Path to the SSL private key file.", default=None)
|
||||
parser.add_argument("--forwarded-allow-ips", type=str, help="Allowed ips for reverse proxying.", default=None)
|
||||
parser.add_argument(
|
||||
"--pcm-input",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="If set, raw PCM (s16le) data is expected as input and FFmpeg will be bypassed. Frontend will use AudioWorklet instead of MediaRecorder."
|
||||
)
|
||||
# SimulStreaming-specific arguments
|
||||
simulstreaming_group = parser.add_argument_group('SimulStreaming arguments (only used with --backend-policy simulstreaming)')
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--disable-fast-encoder",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="disable_fast_encoder",
|
||||
help="Disable Faster Whisper or MLX Whisper backends for encoding (if installed). Slower but helpful when GPU memory is limited",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--custom-alignment-heads",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Use your own alignment heads, useful when `--model-dir` is used",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--frame-threshold",
|
||||
type=int,
|
||||
default=25,
|
||||
dest="frame_threshold",
|
||||
help="Threshold for the attention-guided decoding. The AlignAtt policy will decode only until this number of frames from the end of audio. In frames: one frame is 0.02 seconds for large-v3 model.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--beams",
|
||||
"-b",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of beams for beam search decoding. If 1, GreedyDecoder is used.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--decoder",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="decoder_type",
|
||||
choices=["beam", "greedy"],
|
||||
help="Override automatic selection of beam or greedy decoder. If beams > 1 and greedy: invalid.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-max-len",
|
||||
type=float,
|
||||
default=30.0,
|
||||
dest="audio_max_len",
|
||||
help="Max length of the audio buffer, in seconds.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--audio-min-len",
|
||||
type=float,
|
||||
default=0.0,
|
||||
dest="audio_min_len",
|
||||
help="Skip processing if the audio buffer is shorter than this length, in seconds. Useful when the --min-chunk-size is small.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--cif-ckpt-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="cif_ckpt_path",
|
||||
help="The file path to the Simul-Whisper's CIF model checkpoint that detects whether there is end of word at the end of the chunk. If not, the last decoded space-separated word is truncated because it is often wrong -- transcribing a word in the middle. The CIF model adapted for the Whisper model version should be used. Find the models in https://github.com/backspacetg/simul_whisper/tree/main/cif_models . Note that there is no model for large-v3.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--never-fire",
|
||||
action="store_true",
|
||||
default=False,
|
||||
dest="never_fire",
|
||||
help="Override the CIF model. If True, the last word is NEVER truncated, no matter what the CIF model detects. If False: if CIF model path is set, the last word is SOMETIMES truncated, depending on the CIF detection. Otherwise, if the CIF model path is not set, the last word is ALWAYS trimmed.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--init-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="init_prompt",
|
||||
help="Init prompt for the model. It should be in the target language.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--static-init-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="static_init_prompt",
|
||||
help="Do not scroll over this text. It can contain terminology that should be relevant over all document.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--max-context-tokens",
|
||||
type=int,
|
||||
default=None,
|
||||
dest="max_context_tokens",
|
||||
help="Max context tokens for the model. Default is 0.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default=None,
|
||||
dest="model_path",
|
||||
help="Direct path to the SimulStreaming Whisper .pt model file. Overrides --model for SimulStreaming backend.",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--preload-model-count",
|
||||
type=int,
|
||||
default=1,
|
||||
dest="preload_model_count",
|
||||
help="Optional. Number of models to preload in memory to speed up loading (set up to the expected number of concurrent instances).",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-backend",
|
||||
type=str,
|
||||
default="transformers",
|
||||
help="transformers or ctranslate2",
|
||||
)
|
||||
|
||||
simulstreaming_group.add_argument(
|
||||
"--nllb-size",
|
||||
type=str,
|
||||
default="600M",
|
||||
help="600M or 1.3B",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.transcription = not args.no_transcription
|
||||
args.vad = not args.no_vad
|
||||
delattr(args, 'no_transcription')
|
||||
delattr(args, 'no_vad')
|
||||
|
||||
if args.backend_policy == "1":
|
||||
args.backend_policy = "simulstreaming"
|
||||
elif args.backend_policy == "2":
|
||||
args.backend_policy = "localagreement"
|
||||
|
||||
return args
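A small sanity-check sketch of the post-processing above; the patched argv and program name are illustrative only:
import sys
from whisperlivekit.parse_args import parse_args

sys.argv = ["whisperlivekit-server", "--backend-policy", "2", "--no-vad"]
args = parse_args()
assert args.backend_policy == "localagreement"             # numeric alias "2" is rewritten
assert args.vad is False and args.transcription is True    # --no-* flags are inverted then removed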
|
||||
103
whisperlivekit/remove_silences.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from time import time
|
||||
import re
|
||||
|
||||
MIN_SILENCE_DURATION = 4 #in seconds
|
||||
END_SILENCE_DURATION = 8 #in seconds. Keep this value large to avoid false positives when the model lags significantly
|
||||
END_SILENCE_DURATION_VAC = 3 #VAC is good at detecting silences, but we want to skip the smallest silences
|
||||
|
||||
def blank_to_silence(tokens):
|
||||
full_string = ''.join([t.text for t in tokens])
|
||||
patterns = [re.compile(r'(?:\s*\[BLANK_AUDIO\]\s*)+'), re.compile(r'(?:\s*\[typing\]\s*)+')]
|
||||
matches = []
|
||||
for pattern in patterns:
|
||||
for m in pattern.finditer(full_string):
|
||||
matches.append({
|
||||
'start': m.start(),
|
||||
'end': m.end()
|
||||
})
|
||||
if matches:
|
||||
# cleaned = pattern.sub(' ', full_string).strip()
|
||||
# print("Cleaned:", cleaned)
|
||||
cumulated_len = 0
|
||||
silence_token = None
|
||||
cleaned_tokens = []
|
||||
for token in tokens:
|
||||
if matches:
|
||||
start = cumulated_len
|
||||
end = cumulated_len + len(token.text)
|
||||
cumulated_len = end
|
||||
if start >= matches[0]['start'] and end <= matches[0]['end']:
|
||||
if silence_token: #previous token was already silence
|
||||
silence_token.start = min(silence_token.start, token.start)
|
||||
silence_token.end = max(silence_token.end, token.end)
|
||||
else: #new silence
|
||||
silence_token = ASRToken(
|
||||
start=token.start,
|
||||
end=token.end,
|
||||
speaker=-2,
|
||||
)
|
||||
else:
|
||||
if silence_token: #there was silence but no more
|
||||
if silence_token.duration() >= MIN_SILENCE_DURATION:
|
||||
cleaned_tokens.append(
|
||||
silence_token
|
||||
)
|
||||
silence_token = None
|
||||
matches.pop(0)
|
||||
cleaned_tokens.append(token)
|
||||
# print(cleaned_tokens)
|
||||
return cleaned_tokens
|
||||
return tokens
|
||||
|
||||
def no_token_to_silence(tokens):
|
||||
new_tokens = []
|
||||
silence_token = None
|
||||
for token in tokens:
|
||||
if token.speaker == -2:
|
||||
if new_tokens and new_tokens[-1].speaker == -2: #if token is silence and previous one too
|
||||
new_tokens[-1].end = token.end
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
|
||||
last_end = new_tokens[-1].end if new_tokens else 0.0
|
||||
if token.start - last_end >= MIN_SILENCE_DURATION: #if token is not silence but there is a significant gap
|
||||
if new_tokens and new_tokens[-1].speaker == -2:
|
||||
new_tokens[-1].end = token.start
|
||||
else:
|
||||
silence_token = ASRToken(
|
||||
start=last_end,
|
||||
end=token.start,
|
||||
speaker=-2,
|
||||
)
|
||||
new_tokens.append(silence_token)
|
||||
|
||||
if token.speaker != -2:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
def ends_with_silence(tokens, beg_loop, vac_detected_silence):
|
||||
current_time = time() - (beg_loop if beg_loop else 0.0)
|
||||
last_token = tokens[-1]
|
||||
if vac_detected_silence or (current_time - last_token.end >= END_SILENCE_DURATION):
|
||||
if last_token.speaker == -2:
|
||||
last_token.end = current_time
|
||||
else:
|
||||
tokens.append(
|
||||
ASRToken(
|
||||
start=tokens[-1].end,
|
||||
end=current_time,
|
||||
speaker=-2,
|
||||
)
|
||||
)
|
||||
return tokens
|
||||
|
||||
|
||||
def handle_silences(tokens, beg_loop, vac_detected_silence):
|
||||
if not tokens:
|
||||
return []
|
||||
tokens = blank_to_silence(tokens) #useful for simulstreaming backend which tends to generate [BLANK_AUDIO] text
|
||||
tokens = no_token_to_silence(tokens)
|
||||
tokens = ends_with_silence(tokens, beg_loop, vac_detected_silence)
|
||||
return tokens
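An illustrative run of the pipeline above. The ASRToken keyword arguments are assumptions based on how the class is used in this file (see timed_objects.py for the actual fields):
from time import time
from whisperlivekit.timed_objects import ASRToken
from whisperlivekit.remove_silences import handle_silences

tokens = [
    ASRToken(start=0.0, end=1.0, text="Hello"),
    ASRToken(start=1.0, end=6.0, text=" [BLANK_AUDIO] "),   # 5 s blank -> collapsed into a speaker=-2 silence token
    ASRToken(start=6.0, end=7.0, text=" world"),
]
tokens = handle_silences(tokens, beg_loop=time(), vac_detected_silence=False)
print([(t.speaker, t.start, t.end) for t in tokens])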
|
||||
|
||||
60
whisperlivekit/result_diarization.md
Normal file
@@ -0,0 +1,60 @@
|
||||
########### WHAT IS PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:06
|
||||
Transcription technology has improved so much in the past
|
||||
|
||||
SPEAKER 1 0:00:07 - 0:00:12
|
||||
years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:12
|
||||
Absolutely
|
||||
|
||||
SPEAKER 1 0:00:13 - 0:00:13
|
||||
.
|
||||
|
||||
SPEAKER 2 0:00:14 - 0:00:14
|
||||
I
|
||||
|
||||
SPEAKER 1 0:00:14 - 0:00:17
|
||||
use it all the time for taking notes during meetings.
|
||||
|
||||
SPEAKER 2 0:00:17 - 0:00:17
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:17 - 0:00:22
|
||||
's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 2 0:00:22 - 0:00:22
|
||||
Yeah
|
||||
|
||||
SPEAKER 1 0:00:23 - 0:00:26
|
||||
, but sometimes noise can still cause mistakes.
|
||||
|
||||
SPEAKER 3 0:00:26 - 0:00:27
|
||||
Does
|
||||
|
||||
SPEAKER 1 0:00:27 - 0:00:28
|
||||
this system handle that
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
?
|
||||
|
||||
SPEAKER 3 0:00:29 - 0:00:29
|
||||
It
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:33
|
||||
does a pretty good job filtering noise, especially with models that use voice activity
|
||||
|
||||
########### WHAT SHOULD BE PRODUCED: ###########
|
||||
|
||||
SPEAKER 1 0:00:04 - 0:00:12
|
||||
Transcription technology has improved so much in the past years. Have you noticed how accurate real-time speech detects is now?
|
||||
|
||||
SPEAKER 2 0:00:12 - 0:00:22
|
||||
Absolutely. I use it all the time for taking notes during meetings. It's amazing how it can recognize different speakers, and even add punctuation.
|
||||
|
||||
SPEAKER 3 0:00:22 - 0:00:28
|
||||
Yeah, but sometimes noise can still cause mistakes. Does this system handle that well?
|
||||
|
||||
SPEAKER 1 0:00:29 - 0:00:29
|
||||
It does a pretty good job filtering noise, especially with models that use voice activity
|
||||
257
whisperlivekit/results_formater.py
Normal file
@@ -0,0 +1,257 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from whisperlivekit.remove_silences import handle_silences
|
||||
from whisperlivekit.timed_objects import Line, format_time, SpeakerSegment
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
CHECK_AROUND = 4
|
||||
DEBUG = False
|
||||
|
||||
def next_punctuation_change(i, tokens):
|
||||
for ind in range(i+1, min(len(tokens), i+CHECK_AROUND+1)):
|
||||
if tokens[ind].is_punctuation():
|
||||
return ind
|
||||
return None
|
||||
|
||||
def next_speaker_change(i, tokens, speaker):
|
||||
for ind in range(i-1, max(0, i-CHECK_AROUND)-1, -1):
|
||||
token = tokens[ind]
|
||||
if token.is_punctuation():
|
||||
break
|
||||
if token.speaker != speaker:
|
||||
return ind, token.speaker
|
||||
return None, speaker
|
||||
|
||||
def new_line(
|
||||
token,
|
||||
):
|
||||
return Line(
|
||||
speaker = token.corrected_speaker,
|
||||
text = token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else ""),
|
||||
start = token.start,
|
||||
end = token.end,
|
||||
detected_language=token.detected_language
|
||||
)
|
||||
|
||||
def append_token_to_last_line(lines, sep, token):
|
||||
if not lines:
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
if token.text:
|
||||
lines[-1].text += sep + token.text + (f"[{format_time(token.start)} : {format_time(token.end)}]" if DEBUG else "")
|
||||
lines[-1].end = token.end
|
||||
if not lines[-1].detected_language and token.detected_language:
|
||||
lines[-1].detected_language = token.detected_language
|
||||
|
||||
def extract_number(s) -> int:
|
||||
"""Extract number from speaker string (for diart compatibility)."""
|
||||
if isinstance(s, int):
|
||||
return s
|
||||
m = re.search(r'\d+', str(s))
|
||||
return int(m.group()) if m else 0
|
||||
|
||||
def concatenate_speakers(segments: List[SpeakerSegment]) -> List[dict]:
|
||||
"""Concatenate consecutive segments from the same speaker."""
|
||||
if not segments:
|
||||
return []
|
||||
|
||||
# Get speaker number from first segment
|
||||
first_speaker = extract_number(segments[0].speaker)
|
||||
segments_concatenated = [{"speaker": first_speaker + 1, "begin": segments[0].start, "end": segments[0].end}]
|
||||
|
||||
for segment in segments[1:]:
|
||||
speaker = extract_number(segment.speaker) + 1
|
||||
if segments_concatenated[-1]['speaker'] != speaker:
|
||||
segments_concatenated.append({"speaker": speaker, "begin": segment.start, "end": segment.end})
|
||||
else:
|
||||
segments_concatenated[-1]['end'] = segment.end
|
||||
|
||||
return segments_concatenated
|
||||
|
||||
def add_speaker_to_tokens_with_punctuation(segments: List[SpeakerSegment], tokens: list) -> list:
|
||||
"""Assign speakers to tokens with punctuation-aware boundary adjustment."""
|
||||
punctuation_marks = {'.', '!', '?'}
|
||||
punctuation_tokens = [token for token in tokens if token.text.strip() in punctuation_marks]
|
||||
segments_concatenated = concatenate_speakers(segments)
|
||||
|
||||
for ind, segment in enumerate(segments_concatenated):
|
||||
for i, punctuation_token in enumerate(punctuation_tokens):
|
||||
if punctuation_token.start > segment['end']:
|
||||
after_length = punctuation_token.start - segment['end']
|
||||
before_length = segment['end'] - punctuation_tokens[i - 1].end if i > 0 else float('inf')
|
||||
if before_length > after_length:
|
||||
segment['end'] = punctuation_token.start
|
||||
if i < len(punctuation_tokens) - 1 and ind + 1 < len(segments_concatenated):
|
||||
segments_concatenated[ind + 1]['begin'] = punctuation_token.start
|
||||
else:
|
||||
segment['end'] = punctuation_tokens[i - 1].end if i > 0 else segment['end']
|
||||
if i < len(punctuation_tokens) - 1 and ind - 1 >= 0:
|
||||
segments_concatenated[ind - 1]['begin'] = punctuation_tokens[i - 1].end
|
||||
break
|
||||
|
||||
# Ensure non-overlapping tokens
|
||||
last_end = 0.0
|
||||
for token in tokens:
|
||||
start = max(last_end + 0.01, token.start)
|
||||
token.start = start
|
||||
token.end = max(start, token.end)
|
||||
last_end = token.end
|
||||
|
||||
# Assign speakers based on adjusted segments
|
||||
ind_last_speaker = 0
|
||||
for segment in segments_concatenated:
|
||||
for i, token in enumerate(tokens[ind_last_speaker:]):
|
||||
if token.end <= segment['end']:
|
||||
token.speaker = segment['speaker']
|
||||
ind_last_speaker = i + 1
|
||||
elif token.start > segment['end']:
|
||||
break
|
||||
|
||||
return tokens
|
||||
|
||||
def assign_speakers_to_tokens(tokens: list, segments: List[SpeakerSegment], use_punctuation_split: bool = False) -> list:
|
||||
"""
|
||||
Assign speakers to tokens based on timing overlap with speaker segments.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens with timing information
|
||||
segments: List of speaker segments
|
||||
use_punctuation_split: Whether to use punctuation for boundary refinement
|
||||
|
||||
Returns:
|
||||
List of tokens with speaker assignments
|
||||
"""
|
||||
if not segments or not tokens:
|
||||
logger.debug("No segments or tokens available for speaker assignment")
|
||||
return tokens
|
||||
|
||||
logger.debug(f"Assigning speakers to {len(tokens)} tokens using {len(segments)} segments")
|
||||
|
||||
if not use_punctuation_split:
|
||||
# Simple overlap-based assignment
|
||||
for token in tokens:
|
||||
token.speaker = -1 # Default to no speaker
|
||||
for segment in segments:
|
||||
# Check for timing overlap
|
||||
if not (segment.end <= token.start or segment.start >= token.end):
|
||||
speaker_num = extract_number(segment.speaker)
|
||||
token.speaker = speaker_num + 1 # Convert to 1-based indexing
|
||||
break
|
||||
else:
|
||||
# Use punctuation-aware assignment
|
||||
tokens = add_speaker_to_tokens_with_punctuation(segments, tokens)
|
||||
|
||||
return tokens
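An illustrative call of the simple overlap-based path above. The SpeakerSegment and ASRToken keyword arguments are assumptions based on how the classes are used in this module (see timed_objects.py for the real definitions):
from whisperlivekit.timed_objects import ASRToken, SpeakerSegment
from whisperlivekit.results_formater import assign_speakers_to_tokens

tokens = [ASRToken(start=0.2, end=0.8, text="Hello"), ASRToken(start=1.1, end=1.6, text="there")]
segments = [SpeakerSegment(speaker=0, start=0.0, end=1.0), SpeakerSegment(speaker=1, start=1.0, end=2.0)]
tokens = assign_speakers_to_tokens(tokens, segments)
print([t.speaker for t in tokens])   # 0-based diart speakers become 1-based: [1, 2]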
|
||||
|
||||
def format_output(state, silence, args, sep):
|
||||
diarization = args.diarization
|
||||
disable_punctuation_split = args.disable_punctuation_split
|
||||
tokens = state.tokens
|
||||
translation_validated_segments = state.translation_validated_segments # Speakers will be attributed to these segments based only on the segment timestamps
|
||||
last_validated_token = state.last_validated_token
|
||||
|
||||
last_speaker = abs(state.last_speaker)
|
||||
undiarized_text = []
|
||||
tokens = handle_silences(tokens, state.beg_loop, silence)
|
||||
|
||||
# Assign speakers to tokens based on segments stored in state
|
||||
if False and diarization and state.diarization_segments:
|
||||
use_punctuation_split = args.punctuation_split if hasattr(args, 'punctuation_split') else False
|
||||
tokens = assign_speakers_to_tokens(tokens, state.diarization_segments, use_punctuation_split=use_punctuation_split)
|
||||
for i in range(last_validated_token, len(tokens)):
|
||||
token = tokens[i]
|
||||
speaker = int(token.speaker)
|
||||
token.corrected_speaker = speaker
|
||||
if True or not diarization:
|
||||
if speaker == -1: #Speaker -1 means not attributed by diarization. In the frontend, it should appear under 'Speaker 1'
|
||||
token.corrected_speaker = 1
|
||||
token.validated_speaker = True
|
||||
else:
|
||||
if token.speaker == -1:
|
||||
undiarized_text.append(token.text)
|
||||
elif token.is_punctuation():
|
||||
state.last_punctuation_index = i
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
elif state.last_punctuation_index == i-1:
|
||||
if token.speaker != last_speaker:
|
||||
token.corrected_speaker = token.speaker
|
||||
token.validated_speaker = True
|
||||
# perfect, diarization perfectly aligned
|
||||
else:
|
||||
speaker_change_pos, new_speaker = next_speaker_change(i, tokens, speaker)
|
||||
if speaker_change_pos:
|
||||
# Corrects delay:
|
||||
# That was the idea. <Okay> haha |SPLIT SPEAKER| that's a good one
|
||||
# should become:
|
||||
# That was the idea. |SPLIT SPEAKER| <Okay> haha that's a good one
|
||||
token.corrected_speaker = new_speaker
|
||||
token.validated_speaker = True
|
||||
elif speaker != last_speaker:
|
||||
if not (speaker == -2 or last_speaker == -2):
|
||||
if next_punctuation_change(i, tokens):
|
||||
# Corrects advance:
|
||||
# Are you |SPLIT SPEAKER| <okay>? yeah, sure. Absolutely
|
||||
# should become:
|
||||
# Are you <okay>? |SPLIT SPEAKER| yeah, sure. Absolutely
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = True
|
||||
else: #Problematic, unless the language has no punctuation. We append to the previous line, unless disable_punctuation_split is set to True.
|
||||
if not disable_punctuation_split:
|
||||
token.corrected_speaker = last_speaker
|
||||
token.validated_speaker = False
|
||||
if token.validated_speaker:
|
||||
state.last_validated_token = i
|
||||
state.last_speaker = token.corrected_speaker
|
||||
|
||||
last_speaker = 1
|
||||
|
||||
lines = []
|
||||
for token in tokens:
|
||||
if token.corrected_speaker != -1:
|
||||
if int(token.corrected_speaker) != int(last_speaker):
|
||||
lines.append(new_line(token))
|
||||
else:
|
||||
append_token_to_last_line(lines, sep, token)
|
||||
|
||||
last_speaker = token.corrected_speaker
|
||||
|
||||
if lines:
|
||||
unassigned_translated_segments = []
|
||||
for ts in translation_validated_segments:
|
||||
assigned = False
|
||||
for line in lines:
|
||||
if ts and ts.overlaps_with(line):
|
||||
if ts.is_within(line):
|
||||
line.translation += ts.text + ' '
|
||||
assigned = True
|
||||
break
|
||||
else:
|
||||
ts0, ts1 = ts.approximate_cut_at(line.end)
|
||||
if ts0 and line.overlaps_with(ts0):
|
||||
line.translation += ts0.text + ' '
|
||||
if ts1:
|
||||
unassigned_translated_segments.append(ts1)
|
||||
assigned = True
|
||||
break
|
||||
if not assigned:
|
||||
unassigned_translated_segments.append(ts)
|
||||
|
||||
if unassigned_translated_segments:
|
||||
for line in lines:
|
||||
remaining_segments = []
|
||||
for ts in unassigned_translated_segments:
|
||||
if ts and ts.overlaps_with(line):
|
||||
line.translation += ts.text + ' '
|
||||
else:
|
||||
remaining_segments.append(ts)
|
||||
unassigned_translated_segments = remaining_segments # TODO: handle any remaining unassigned translated segments
|
||||
|
||||
if state.buffer_transcription and lines:
|
||||
lines[-1].end = max(state.buffer_transcription.end, lines[-1].end)
|
||||
|
||||
return lines, undiarized_text
|
||||
294
whisperlivekit/silero_vad_iterator.py
Normal file
@@ -0,0 +1,294 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
"""
|
||||
Code is adapted from silero-vad v6: https://github.com/snakers4/silero-vad
|
||||
"""
|
||||
|
||||
def init_jit_model(model_path: str, device=torch.device('cpu')):
|
||||
"""Load a JIT model from file."""
|
||||
model = torch.jit.load(model_path, map_location=device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class OnnxWrapper():
|
||||
"""ONNX Runtime wrapper for Silero VAD model."""
|
||||
|
||||
def __init__(self, path, force_onnx_cpu=False):
|
||||
global np
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
|
||||
opts = onnxruntime.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
|
||||
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
|
||||
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
|
||||
else:
|
||||
self.session = onnxruntime.InferenceSession(path, sess_options=opts)
|
||||
|
||||
self.reset_states()
|
||||
if '16k' in path:
|
||||
warnings.warn('This model supports only a 16000 Hz sampling rate!')
|
||||
self.sample_rates = [16000]
|
||||
else:
|
||||
self.sample_rates = [8000, 16000]
|
||||
|
||||
def _validate_input(self, x, sr: int):
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
|
||||
if sr != 16000 and (sr % 16000 == 0):
|
||||
step = sr // 16000
|
||||
x = x[:,::step]
|
||||
sr = 16000
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
|
||||
if sr / x.shape[1] > 31.25:
|
||||
raise ValueError("Input audio chunk is too short")
|
||||
|
||||
return x, sr
|
||||
|
||||
def reset_states(self, batch_size=1):
|
||||
self._state = torch.zeros((2, batch_size, 128)).float()
|
||||
self._context = torch.zeros(0)
|
||||
self._last_sr = 0
|
||||
self._last_batch_size = 0
|
||||
|
||||
def __call__(self, x, sr: int):
|
||||
|
||||
x, sr = self._validate_input(x, sr)
|
||||
num_samples = 512 if sr == 16000 else 256
|
||||
|
||||
if x.shape[-1] != num_samples:
|
||||
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
|
||||
|
||||
batch_size = x.shape[0]
|
||||
context_size = 64 if sr == 16000 else 32
|
||||
|
||||
if not self._last_batch_size:
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_sr) and (self._last_sr != sr):
|
||||
self.reset_states(batch_size)
|
||||
if (self._last_batch_size) and (self._last_batch_size != batch_size):
|
||||
self.reset_states(batch_size)
|
||||
|
||||
if not len(self._context):
|
||||
self._context = torch.zeros(batch_size, context_size)
|
||||
|
||||
x = torch.cat([self._context, x], dim=1)
|
||||
if sr in [8000, 16000]:
|
||||
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
|
||||
ort_outs = self.session.run(None, ort_inputs)
|
||||
out, state = ort_outs
|
||||
self._state = torch.from_numpy(state)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self._context = x[..., -context_size:]
|
||||
self._last_sr = sr
|
||||
self._last_batch_size = batch_size
|
||||
|
||||
out = torch.from_numpy(out)
|
||||
return out
|
||||
|
||||
|
||||
def load_silero_vad(model_path: str = None, onnx: bool = False, opset_version: int = 16):
|
||||
"""
|
||||
Load Silero VAD model (JIT or ONNX).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_path : str, optional
|
||||
Path to model file. If None, uses default bundled model.
|
||||
onnx : bool, default False
|
||||
Whether to use ONNX runtime (requires onnxruntime package).
|
||||
opset_version : int, default 16
|
||||
ONNX opset version (15 or 16). Only used if onnx=True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
model
|
||||
Loaded VAD model (JIT or ONNX wrapper)
|
||||
"""
|
||||
available_ops = [15, 16]
|
||||
if onnx and opset_version not in available_ops:
|
||||
raise Exception(f'Available ONNX opset_version: {available_ops}')
|
||||
if model_path is None:
|
||||
current_dir = Path(__file__).parent
|
||||
data_dir = current_dir / 'vad_models'
|
||||
|
||||
if onnx:
|
||||
if opset_version == 16:
|
||||
model_name = 'silero_vad.onnx'
|
||||
else:
|
||||
model_name = f'silero_vad_16k_op{opset_version}.onnx'
|
||||
else:
|
||||
model_name = 'silero_vad.jit'
|
||||
|
||||
model_path = data_dir / model_name
|
||||
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found: {model_path}\n"
|
||||
f"Please ensure the whisperlivekit/vad_models/ directory contains the model files."
|
||||
)
|
||||
else:
|
||||
model_path = Path(model_path)
|
||||
if onnx:
|
||||
try:
|
||||
model = OnnxWrapper(str(model_path), force_onnx_cpu=True)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"ONNX runtime not available. Install with: pip install onnxruntime\n"
|
||||
"Or use JIT model by setting onnx=False"
|
||||
)
|
||||
else:
|
||||
model = init_jit_model(str(model_path))
|
||||
|
||||
return model
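A brief illustration of the loader together with the streaming wrapper defined further below in this file (the bundled silero_vad.jit under vad_models/ is assumed to be present):
import numpy as np
from whisperlivekit.silero_vad_iterator import load_silero_vad, FixedVADIterator

model = load_silero_vad(onnx=False)                        # JIT model shipped with the package
vad = FixedVADIterator(model)
event = vad(np.zeros(1024, dtype=np.float32), return_seconds=True)
print(event)                                               # {'start': ...}, {'end': ...} or None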
|
||||
|
||||
|
||||
class VADIterator:
|
||||
"""
|
||||
Voice Activity Detection iterator for streaming audio.
|
||||
|
||||
This is the Silero VAD v6 implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
threshold: float = 0.5,
|
||||
sampling_rate: int = 16000,
|
||||
min_silence_duration_ms: int = 100,
|
||||
speech_pad_ms: int = 30
|
||||
):
|
||||
|
||||
"""
|
||||
Class for stream imitation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: preloaded .jit/.onnx silero VAD model
|
||||
|
||||
threshold: float (default - 0.5)
|
||||
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
|
||||
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
|
||||
sampling_rate: int (default - 16000)
|
||||
Currently silero VAD models support 8000 and 16000 sample rates
|
||||
|
||||
min_silence_duration_ms: int (default - 100 milliseconds)
|
||||
At the end of each speech chunk, wait for min_silence_duration_ms before separating it
|
||||
|
||||
speech_pad_ms: int (default - 30 milliseconds)
|
||||
Final speech chunks are padded by speech_pad_ms each side
|
||||
"""
|
||||
|
||||
self.model = model
|
||||
self.threshold = threshold
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
if sampling_rate not in [8000, 16000]:
|
||||
raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
|
||||
|
||||
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
||||
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
||||
self.reset_states()
|
||||
|
||||
def reset_states(self):
|
||||
|
||||
self.model.reset_states()
|
||||
self.triggered = False
|
||||
self.temp_end = 0
|
||||
self.current_sample = 0
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, x, return_seconds=False, time_resolution: int = 1):
|
||||
"""
|
||||
x: torch.Tensor
|
||||
audio chunk (see examples in repo)
|
||||
|
||||
return_seconds: bool (default - False)
|
||||
whether return timestamps in seconds (default - samples)
|
||||
|
||||
time_resolution: int (default - 1)
|
||||
time resolution of speech coordinates when requested as seconds
|
||||
"""
|
||||
|
||||
if not torch.is_tensor(x):
|
||||
try:
|
||||
x = torch.Tensor(x)
|
||||
except:
|
||||
raise TypeError("Audio cannot be casted to tensor. Cast it manually")
|
||||
|
||||
window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
|
||||
self.current_sample += window_size_samples
|
||||
|
||||
speech_prob = self.model(x, self.sampling_rate).item()
|
||||
|
||||
if (speech_prob >= self.threshold) and self.temp_end:
|
||||
self.temp_end = 0
|
||||
|
||||
if (speech_prob >= self.threshold) and not self.triggered:
|
||||
self.triggered = True
|
||||
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
||||
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, time_resolution)}
|
||||
|
||||
if (speech_prob < self.threshold - 0.15) and self.triggered:
|
||||
if not self.temp_end:
|
||||
self.temp_end = self.current_sample
|
||||
if self.current_sample - self.temp_end < self.min_silence_samples:
|
||||
return None
|
||||
else:
|
||||
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
||||
self.temp_end = 0
|
||||
self.triggered = False
|
||||
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, time_resolution)}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class FixedVADIterator(VADIterator):
|
||||
"""
|
||||
Fixed VAD Iterator that handles variable-length audio chunks, not only chunks of exactly 512 samples at once.
|
||||
"""
|
||||
|
||||
def reset_states(self):
|
||||
super().reset_states()
|
||||
self.buffer = np.array([], dtype=np.float32)
|
||||
|
||||
def __call__(self, x, return_seconds=False):
|
||||
self.buffer = np.append(self.buffer, x)
|
||||
ret = None
|
||||
while len(self.buffer) >= 512:
|
||||
r = super().__call__(self.buffer[:512], return_seconds=return_seconds)
|
||||
self.buffer = self.buffer[512:]
|
||||
if ret is None:
|
||||
ret = r
|
||||
elif r is not None:
|
||||
if "end" in r:
|
||||
ret["end"] = r["end"]
|
||||
if "start" in r and "end" in ret:
|
||||
del ret["end"]
|
||||
return ret if ret != {} else None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = load_silero_vad(onnx=False)
|
||||
vad = FixedVADIterator(model)
|
||||
|
||||
audio_buffer = np.array([0] * 512, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
print(f" 512 samples: {result}")
|
||||
|
||||
# test with 511 samples
|
||||
audio_buffer = np.array([0] * 511, dtype=np.float32)
|
||||
result = vad(audio_buffer)
|
||||
6
whisperlivekit/simul_whisper/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .backend import SimulStreamingASR, SimulStreamingOnlineProcessor
|
||||
|
||||
__all__ = [
|
||||
"SimulStreamingASR",
|
||||
"SimulStreamingOnlineProcessor",
|
||||
]
|
||||
355
whisperlivekit/simul_whisper/backend.py
Normal file
@@ -0,0 +1,355 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing import List, Tuple, Optional
|
||||
import platform
|
||||
from whisperlivekit.timed_objects import ASRToken, Transcript, ChangeSpeaker
|
||||
from whisperlivekit.warmup import load_file
|
||||
from whisperlivekit.whisper import load_model, tokenizer
|
||||
from whisperlivekit.whisper.audio import TOKENS_PER_SECOND
|
||||
import os
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from whisperlivekit.model_paths import model_path_and_type, resolve_model_path
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
import torch
|
||||
from whisperlivekit.simul_whisper.config import AlignAttConfig
|
||||
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
HAS_MLX_WHISPER = mlx_backend_available(warn_on_missing=True)
|
||||
if HAS_MLX_WHISPER:
|
||||
from .mlx_encoder import mlx_model_mapping, load_mlx_encoder
|
||||
else:
|
||||
mlx_model_mapping = {}
|
||||
HAS_FASTER_WHISPER = faster_backend_available(warn_on_missing=not HAS_MLX_WHISPER)
|
||||
if HAS_FASTER_WHISPER:
|
||||
from faster_whisper import WhisperModel
|
||||
else:
|
||||
WhisperModel = None
|
||||
|
||||
MIN_DURATION_REAL_SILENCE = 5
|
||||
|
||||
class SimulStreamingOnlineProcessor:
|
||||
SAMPLING_RATE = 16000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr,
|
||||
logfile=sys.stderr,
|
||||
):
|
||||
self.asr = asr
|
||||
self.logfile = logfile
|
||||
self.end = 0.0
|
||||
self.buffer = []
|
||||
self.committed: List[ASRToken] = []
|
||||
self.last_result_tokens: List[ASRToken] = []
|
||||
self.load_new_backend()
|
||||
|
||||
#can be moved
|
||||
if asr.tokenizer:
|
||||
self.model.tokenizer = asr.tokenizer
|
||||
|
||||
def load_new_backend(self):
|
||||
model = self.asr.get_new_model_instance()
|
||||
self.model = AlignAtt(
|
||||
cfg=self.asr.cfg,
|
||||
loaded_model=model,
|
||||
mlx_encoder=self.asr.mlx_encoder,
|
||||
fw_encoder=self.asr.fw_encoder,
|
||||
)
|
||||
|
||||
def start_silence(self):
|
||||
tokens, processed_upto = self.process_iter(is_last=True)
|
||||
return tokens, processed_upto
|
||||
|
||||
def end_silence(self, silence_duration, offset):
|
||||
"""
|
||||
If the silence is longer than MIN_DURATION_REAL_SILENCE, we do a complete context clear. Otherwise, we just insert a short silence and shift the last_attend_frame.
|
||||
"""
|
||||
self.end += silence_duration
|
||||
long_silence = silence_duration >= MIN_DURATION_REAL_SILENCE
|
||||
if not long_silence:
|
||||
gap_len = int(16000 * silence_duration)
|
||||
if gap_len > 0:
|
||||
gap_silence = torch.zeros(gap_len)
|
||||
self.model.insert_audio(gap_silence)
|
||||
if long_silence:
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.global_time_offset = silence_duration + offset
|
||||
|
||||
|
||||
|
||||
def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time):
|
||||
"""Append an audio chunk to be processed by SimulStreaming."""
|
||||
|
||||
# Convert numpy array to torch tensor
|
||||
audio_tensor = torch.from_numpy(audio).float()
|
||||
self.end = audio_stream_end_time # Only to stay aligned with what happens in the whisper-streaming backend.
|
||||
self.model.insert_audio(audio_tensor)
|
||||
|
||||
def new_speaker(self, change_speaker: ChangeSpeaker):
|
||||
self.process_iter(is_last=True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
self.model.speaker = change_speaker.speaker
|
||||
self.global_time_offset = change_speaker.start
|
||||
|
||||
def get_buffer(self):
|
||||
concat_buffer = Transcript.from_tokens(tokens= self.buffer, sep='')
|
||||
return concat_buffer
|
||||
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
Process accumulated audio chunks using SimulStreaming.
|
||||
|
||||
Returns a tuple: (list of committed ASRToken objects, float representing the audio processed up to time).
|
||||
"""
|
||||
try:
|
||||
timestamped_words = self.model.infer(is_last=is_last)
|
||||
if self.model.cfg.language == "auto" and timestamped_words and timestamped_words[0].detected_language == None:
|
||||
self.buffer.extend(timestamped_words)
|
||||
return [], self.end
|
||||
|
||||
self.committed.extend(timestamped_words)
|
||||
self.buffer = []
|
||||
return timestamped_words, self.end
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming processing error: {e}")
|
||||
return [], self.end
|
||||
|
||||
def warmup(self, audio, init_prompt=""):
|
||||
"""Warmup the SimulStreaming model."""
|
||||
try:
|
||||
self.model.insert_audio(audio)
|
||||
self.model.infer(True)
|
||||
self.model.refresh_segment(complete=True)
|
||||
logger.info("SimulStreaming model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"SimulStreaming warmup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
# free the model and add a new model to stack.
|
||||
# del self.model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
# self.asr.new_model_to_stack()
|
||||
self.model.remove_hooks()
|
||||
|
||||
class SimulStreamingASR():
|
||||
"""SimulStreaming backend with AlignAtt policy."""
|
||||
sep = ""
|
||||
|
||||
def __init__(self, logfile=sys.stderr, **kwargs):
|
||||
self.logfile = logfile
|
||||
self.transcribe_kargs = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if self.decoder_type is None:
|
||||
self.decoder_type = 'greedy' if self.beams == 1 else 'beam'
|
||||
|
||||
self.fast_encoder = False
|
||||
self._resolved_model_path = None
|
||||
self.encoder_backend = "whisper"
|
||||
preferred_backend = getattr(self, "backend", "auto")
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = None, True, True
|
||||
if self.model_path:
|
||||
resolved_model_path = resolve_model_path(self.model_path)
|
||||
self._resolved_model_path = resolved_model_path
|
||||
self.model_path = str(resolved_model_path)
|
||||
self.pytorch_path, compatible_whisper_mlx, compatible_faster_whisper = model_path_and_type(resolved_model_path)
|
||||
if self.pytorch_path:
|
||||
self.model_name = self.pytorch_path.stem
|
||||
else:
|
||||
self.model_name = Path(self.model_path).stem
|
||||
raise FileNotFoundError(
|
||||
f"No PyTorch checkpoint (.pt/.bin/.safetensors) found under {self.model_path}"
|
||||
)
|
||||
elif self.model_size is not None:
|
||||
model_mapping = {
|
||||
'tiny': './tiny.pt',
|
||||
'base': './base.pt',
|
||||
'small': './small.pt',
|
||||
'medium': './medium.pt',
|
||||
'medium.en': './medium.en.pt',
|
||||
'large-v1': './large-v1.pt',
|
||||
'base.en': './base.en.pt',
|
||||
'small.en': './small.en.pt',
|
||||
'tiny.en': './tiny.en.pt',
|
||||
'large-v2': './large-v2.pt',
|
||||
'large-v3': './large-v3.pt',
|
||||
'large': './large-v3.pt'
|
||||
}
|
||||
self.model_name = self.model_size
|
||||
else:
|
||||
raise ValueError("Either model_size or model_path must be specified for SimulStreaming.")
|
||||
|
||||
is_multilingual = not self.model_name.endswith(".en")
|
||||
|
||||
self.encoder_backend = self._resolve_encoder_backend(
|
||||
preferred_backend,
|
||||
compatible_whisper_mlx,
|
||||
compatible_faster_whisper,
|
||||
)
|
||||
self.fast_encoder = self.encoder_backend in ("mlx-whisper", "faster-whisper")
|
||||
if self.encoder_backend == "whisper":
|
||||
self.disable_fast_encoder = True
|
||||
|
||||
self.cfg = AlignAttConfig(
|
||||
tokenizer_is_multilingual= is_multilingual,
|
||||
segment_length=self.min_chunk_size,
|
||||
frame_threshold=self.frame_threshold,
|
||||
language=self.lan,
|
||||
audio_max_len=self.audio_max_len,
|
||||
audio_min_len=self.audio_min_len,
|
||||
cif_ckpt_path=self.cif_ckpt_path,
|
||||
decoder_type="beam",
|
||||
beam_size=self.beams,
|
||||
task=self.direct_english_translation,
|
||||
never_fire=self.never_fire,
|
||||
init_prompt=self.init_prompt,
|
||||
max_context_tokens=self.max_context_tokens,
|
||||
static_init_prompt=self.static_init_prompt,
|
||||
)
|
||||
|
||||
# Set up tokenizer for translation if needed
|
||||
if self.direct_english_translation:
|
||||
self.tokenizer = self.set_translate_task()
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
|
||||
|
||||
|
||||
self.mlx_encoder, self.fw_encoder = None, None
|
||||
if self.encoder_backend == "mlx-whisper":
|
||||
print('SimulStreaming will use MLX Whisper to speed up encoding.')
|
||||
if self._resolved_model_path is not None:
|
||||
mlx_model = str(self._resolved_model_path)
|
||||
else:
|
||||
mlx_model = mlx_model_mapping.get(self.model_name)
|
||||
if not mlx_model:
|
||||
raise FileNotFoundError(
|
||||
f"MLX Whisper backend requested but no compatible weights found for model '{self.model_name}'."
|
||||
)
|
||||
self.mlx_encoder = load_mlx_encoder(path_or_hf_repo=mlx_model)
|
||||
elif self.encoder_backend == "faster-whisper":
|
||||
print('SimulStreaming will use Faster-Whisper for the encoder.')
|
||||
if self._resolved_model_path is not None:
|
||||
fw_model = str(self._resolved_model_path)
|
||||
else:
|
||||
fw_model = self.model_name
|
||||
self.fw_encoder = WhisperModel(
|
||||
fw_model,
|
||||
device='auto',
|
||||
compute_type='auto',
|
||||
)
|
||||
|
||||
self.models = [self.load_model() for i in range(self.preload_model_count)]
|
||||
|
||||
|
||||
def _resolve_encoder_backend(self, preferred_backend, compatible_whisper_mlx, compatible_faster_whisper):
|
||||
choice = preferred_backend or "auto"
|
||||
if self.disable_fast_encoder:
|
||||
return "whisper"
|
||||
if choice == "whisper":
|
||||
return "whisper"
|
||||
if choice == "mlx-whisper":
|
||||
if not self._can_use_mlx(compatible_whisper_mlx):
|
||||
raise RuntimeError("mlx-whisper backend requested but MLX Whisper is unavailable or incompatible with the provided model.")
|
||||
return "mlx-whisper"
|
||||
if choice == "faster-whisper":
|
||||
if not self._can_use_faster(compatible_faster_whisper):
|
||||
raise RuntimeError("faster-whisper backend requested but Faster-Whisper is unavailable or incompatible with the provided model.")
|
||||
return "faster-whisper"
|
||||
if choice == "openai-api":
|
||||
raise ValueError("openai-api backend is only supported with the LocalAgreement policy.")
|
||||
# auto mode
|
||||
if platform.system() == "Darwin" and self._can_use_mlx(compatible_whisper_mlx):
|
||||
return "mlx-whisper"
|
||||
if self._can_use_faster(compatible_faster_whisper):
|
||||
return "faster-whisper"
|
||||
return "whisper"
|
||||
|
||||
def _has_custom_model_path(self):
|
||||
return self._resolved_model_path is not None
|
||||
|
||||
def _can_use_mlx(self, compatible_whisper_mlx):
|
||||
if not HAS_MLX_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_whisper_mlx
|
||||
return self.model_name in mlx_model_mapping
|
||||
|
||||
def _can_use_faster(self, compatible_faster_whisper):
|
||||
if not HAS_FASTER_WHISPER:
|
||||
return False
|
||||
if self._has_custom_model_path():
|
||||
return compatible_faster_whisper
|
||||
return True
|
||||
|
||||
def load_model(self):
|
||||
whisper_model = load_model(
|
||||
name=self.pytorch_path if self.pytorch_path else self.model_name,
|
||||
download_root=self.model_path,
|
||||
decoder_only=self.fast_encoder,
|
||||
custom_alignment_heads=self.custom_alignment_heads
|
||||
)
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
if warmup_audio is not None:
|
||||
warmup_audio = torch.from_numpy(warmup_audio).float()
|
||||
if self.fast_encoder:
|
||||
temp_model = AlignAtt(
|
||||
cfg=self.cfg,
|
||||
loaded_model=whisper_model,
|
||||
mlx_encoder=self.mlx_encoder,
|
||||
fw_encoder=self.fw_encoder,
|
||||
)
|
||||
temp_model.warmup(warmup_audio)
|
||||
temp_model.remove_hooks()
|
||||
else:
|
||||
# For standard encoder, use the original transcribe warmup
|
||||
warmup_audio = load_file(self.warmup_file)
|
||||
whisper_model.transcribe(warmup_audio, language=self.lan if self.lan != 'auto' else None)
|
||||
return whisper_model
|
||||
|
||||
def get_new_model_instance(self):
|
||||
"""
|
||||
SimulStreaming cannot share the same backend because it uses global forward hooks on the attention layers.
|
||||
Therefore, each user requires a separate model instance, which can be memory-intensive. To maintain speed, we preload the models into memory.
|
||||
"""
|
||||
if len(self.models) == 0:
|
||||
self.models.append(self.load_model())
|
||||
new_model = self.models.pop()
|
||||
return new_model
|
||||
# self.models[0]
|
||||
|
||||
def new_model_to_stack(self):
|
||||
self.models.append(self.load_model())
|
||||
|
||||
|
||||
def set_translate_task(self):
|
||||
"""Set up translation task."""
|
||||
if self.cfg.language == 'auto':
|
||||
raise Exception('Translation cannot be done with language = auto')
|
||||
return tokenizer.get_tokenizer(
|
||||
multilingual=True,
|
||||
language=self.cfg.language,
|
||||
num_languages=99,
|
||||
task="translate"
|
||||
)
|
||||
|
||||
def transcribe(self, audio):
|
||||
"""
|
||||
Warmup is done directly in load_model
|
||||
"""
|
||||
pass
|
||||
17
whisperlivekit/simul_whisper/beam.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from whisperlivekit.whisper.decoding import PyTorchInference
|
||||
|
||||
# Extension of PyTorchInference for beam search
|
||||
class BeamPyTorchInference(PyTorchInference):
|
||||
|
||||
def _kv_modules(self):
|
||||
key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks]
|
||||
return key_modules + value_modules
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
for module_cache_id in self._kv_modules():
|
||||
self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach()
|
||||
from torch import Tensor
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
23
whisperlivekit/simul_whisper/config.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
@dataclass
|
||||
class AlignAttConfig():
|
||||
eval_data_path: str = "tmp"
|
||||
segment_length: float = field(default=1.0, metadata = {"help": "in second"})
|
||||
frame_threshold: int = 4
|
||||
rewind_threshold: int = 200
|
||||
audio_max_len: float = 20.0
|
||||
cif_ckpt_path: str = ""
|
||||
never_fire: bool = False
|
||||
language: str = field(default="zh")
|
||||
nonspeech_prob: float = 0.5
|
||||
audio_min_len: float = 1.0
|
||||
decoder_type: Literal["greedy","beam"] = "greedy"
|
||||
beam_size: int = 5
|
||||
task: Literal["transcribe","translate"] = "transcribe"
|
||||
tokenizer_is_multilingual: bool = False
|
||||
init_prompt: str = field(default=None)
|
||||
static_init_prompt: str = field(default=None)
|
||||
max_context_tokens: int = field(default=None)
|
||||
|
||||
65
whisperlivekit/simul_whisper/eow_detection.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
|
||||
# code for the end-of-word detection based on the CIF model proposed in Simul-Whisper
|
||||
|
||||
def load_cif(cfg, n_audio_state, device):
|
||||
"""cfg: AlignAttConfig, n_audio_state: int, device: torch.device"""
|
||||
cif_linear = torch.nn.Linear(n_audio_state, 1)
|
||||
if cfg.cif_ckpt_path is None or not cfg.cif_ckpt_path:
|
||||
if cfg.never_fire:
|
||||
never_fire = True
|
||||
always_fire = False
|
||||
else:
|
||||
always_fire = True
|
||||
never_fire = False
|
||||
else:
|
||||
always_fire = False
|
||||
never_fire = cfg.never_fire
|
||||
checkpoint = torch.load(cfg.cif_ckpt_path)
|
||||
cif_linear.load_state_dict(checkpoint)
|
||||
cif_linear.to(device)
|
||||
return cif_linear, always_fire, never_fire
|
||||
|
||||
|
||||
# from https://github.com/dqqcasia/mosst/blob/master/fairseq/models/speech_to_text/convtransformer_wav2vec_cif.py
|
||||
def resize(alphas, target_lengths, threshold=0.999):
|
||||
"""
|
||||
alpha in thresh=1.0 | (0.0, +0.21)
|
||||
target_lengths: if None, apply round and resize, else apply scaling
|
||||
"""
|
||||
# sum
|
||||
_num = alphas.sum(-1)
|
||||
num = target_lengths.float()
|
||||
# scaling
|
||||
_alphas = alphas * (num / _num)[:, None].repeat(1, alphas.size(1))
|
||||
# remove attention values that exceed the threshold
|
||||
count = 0
|
||||
while len(torch.where(_alphas > threshold)[0]):
|
||||
count += 1
|
||||
if count > 10:
|
||||
break
|
||||
xs, ys = torch.where(_alphas > threshold)
|
||||
for x, y in zip(xs, ys):
|
||||
if _alphas[x][y] >= threshold:
|
||||
mask = _alphas[x].ne(0).float()
|
||||
mean = 0.5 * _alphas[x].sum() / mask.sum()
|
||||
_alphas[x] = _alphas[x] * 0.5 + mean * mask
|
||||
|
||||
return _alphas, _num
|
||||
|
||||
def fire_at_boundary(chunked_encoder_feature: torch.Tensor, cif_linear):
|
||||
content_mel_len = chunked_encoder_feature.shape[1] # B, T, D
|
||||
alphas = cif_linear(chunked_encoder_feature).squeeze(dim=2) # B, T
|
||||
alphas = torch.sigmoid(alphas)
|
||||
decode_length = torch.round(alphas.sum(-1)).int()
|
||||
alphas, _ = resize(alphas, decode_length)
|
||||
alphas = alphas.squeeze(0) # (T, )
|
||||
threshold = 0.999
|
||||
integrate = torch.cumsum(alphas[:-1], dim=0) # ignore the peak value at the end of the content chunk
|
||||
exceed_count = integrate[-1] // threshold
|
||||
integrate = integrate - exceed_count*1.0 # subtract 1 every time the integrated value exceeds the threshold
|
||||
important_positions = (integrate >= 0).nonzero(as_tuple=True)[0]
|
||||
if important_positions.numel() == 0:
|
||||
return False
|
||||
else:
|
||||
return important_positions[0] >= content_mel_len-2
|
||||
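A sketch of how the two helpers above fit together: load_cif returns the CIF linear probe plus the always_fire/never_fire flags, and fire_at_boundary then decides whether the last encoder frames end on a word boundary. The 384-dimensional dummy encoder output is an assumption (tiny-sized model); with an empty cif_ckpt_path the probe is bypassed and always_fire is True.

import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.eow_detection import load_cif, fire_at_boundary

cfg = AlignAttConfig(cif_ckpt_path="", never_fire=False)   # no checkpoint -> always_fire
cif_linear, always_fire, never_fire = load_cif(cfg, n_audio_state=384, device=torch.device("cpu"))

encoder_feature = torch.randn(1, 100, 384)   # dummy (batch, frames, dim) encoder output
if always_fire:
    boundary = True
elif never_fire:
    boundary = False
else:
    boundary = fire_at_boundary(encoder_feature, cif_linear)
print(boundary)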
72
whisperlivekit/simul_whisper/mlx_encoder.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from mlx_whisper import whisper
|
||||
|
||||
mlx_model_mapping = {
|
||||
"tiny.en": "mlx-community/whisper-tiny.en-mlx",
|
||||
"tiny": "mlx-community/whisper-tiny-mlx",
|
||||
"base.en": "mlx-community/whisper-base.en-mlx",
|
||||
"base": "mlx-community/whisper-base-mlx",
|
||||
"small.en": "mlx-community/whisper-small.en-mlx",
|
||||
"small": "mlx-community/whisper-small-mlx",
|
||||
"medium.en": "mlx-community/whisper-medium.en-mlx",
|
||||
"medium": "mlx-community/whisper-medium-mlx",
|
||||
"large-v1": "mlx-community/whisper-large-v1-mlx",
|
||||
"large-v2": "mlx-community/whisper-large-v2-mlx",
|
||||
"large-v3": "mlx-community/whisper-large-v3-mlx",
|
||||
"large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
|
||||
"large": "mlx-community/whisper-large-mlx",
|
||||
}
|
||||
|
||||
def load_mlx_encoder(
|
||||
path_or_hf_repo: str,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
) -> whisper.Whisper:
|
||||
model_path = Path(path_or_hf_repo)
|
||||
if not model_path.exists():
|
||||
model_path = Path(snapshot_download(repo_id=path_or_hf_repo))
|
||||
|
||||
with open(str(model_path / "config.json"), "r") as f:
|
||||
config = json.loads(f.read())
|
||||
config.pop("model_type", None)
|
||||
quantization = config.pop("quantization", None)
|
||||
|
||||
model_args = whisper.ModelDimensions(**config)
|
||||
|
||||
wf = model_path / "weights.safetensors"
|
||||
if not wf.exists():
|
||||
wf = model_path / "weights.npz"
|
||||
weights = mx.load(str(wf))
|
||||
|
||||
model = whisper.Whisper(model_args, dtype)
|
||||
|
||||
if quantization is not None:
|
||||
class_predicate = (
|
||||
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
|
||||
and f"{p}.scales" in weights
|
||||
)
|
||||
nn.quantize(model, **quantization, class_predicate=class_predicate)
|
||||
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
|
||||
# we only want to load the encoder weights here.
|
||||
# Size examples: for tiny.en,
|
||||
# Decoder weights: 59110771 bytes
|
||||
# Encoder weights: 15268874 bytes
|
||||
|
||||
|
||||
encoder_weights = {}
|
||||
encoder_weights['encoder'] = weights['encoder']
|
||||
del weights
|
||||
|
||||
|
||||
|
||||
model.update(encoder_weights)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
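A rough sketch of using the MLX encoder on Apple Silicon; it assumes the mlx, mlx_whisper and huggingface_hub packages are installed and that the base model uses an 80-bin mel front end. Only encoder features are produced here and handed back to the PyTorch decoder path, since the decoder weights are deliberately not loaded above.

import mlx.core as mx
import numpy as np
import torch
from whisperlivekit.simul_whisper.mlx_encoder import load_mlx_encoder, mlx_model_mapping

mlx_model = load_mlx_encoder(mlx_model_mapping["base"])   # downloads from the HF hub if needed

mel = mx.zeros((3000, 80))                  # stand-in for a 30 s log-mel spectrogram (frames, n_mels)
features = mlx_model.encoder(mel[None])     # (1, 1500, n_audio_state)
features_pt = torch.as_tensor(np.array(features))   # hand the features to the PyTorch decoder
print(features_pt.shape)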
660
whisperlivekit/simul_whisper/simul_whisper.py
Normal file
@@ -0,0 +1,660 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from whisperlivekit.whisper import DecodingOptions, tokenizer
|
||||
from .config import AlignAttConfig
|
||||
from whisperlivekit.timed_objects import ASRToken
|
||||
from whisperlivekit.whisper.audio import log_mel_spectrogram, TOKENS_PER_SECOND, pad_or_trim, N_SAMPLES, N_FRAMES
|
||||
from whisperlivekit.whisper.timing import median_filter
|
||||
from whisperlivekit.whisper.decoding import GreedyDecoder, BeamSearchDecoder, SuppressTokens
|
||||
from .beam import BeamPyTorchInference
|
||||
from .eow_detection import fire_at_boundary, load_cif
|
||||
from time import time
|
||||
from .token_buffer import TokenBuffer
|
||||
from whisperlivekit.backend_support import (
|
||||
mlx_backend_available,
|
||||
faster_backend_available,
|
||||
)
|
||||
|
||||
from ..timed_objects import PUNCTUATION_MARKS
|
||||
|
||||
DEC_PAD = 50257
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if mlx_backend_available():
|
||||
from mlx_whisper.audio import log_mel_spectrogram as mlx_log_mel_spectrogram
|
||||
from mlx_whisper.transcribe import pad_or_trim as mlx_pad_or_trim
|
||||
|
||||
if faster_backend_available():
|
||||
from faster_whisper.audio import pad_or_trim as fw_pad_or_trim
|
||||
from faster_whisper.feature_extractor import FeatureExtractor
|
||||
|
||||
USE_MLCORE = False
|
||||
|
||||
|
||||
def load_coreml_encoder():
|
||||
try:
|
||||
from coremltools.models import MLModel
|
||||
except ImportError:
|
||||
logger.warning("coremltools is not installed")
|
||||
return None
|
||||
COREML_ENCODER_PATH = os.environ.get("MLCORE_ENCODER_PATH", "whisperlivekit/whisper/whisper_encoder.mlpackage")
|
||||
_coreml_encoder = MLModel(COREML_ENCODER_PATH)
|
||||
spec = _coreml_encoder.get_spec()
|
||||
_coreml_input_name = spec.description.input[0].name if spec.description.input else "mel"
|
||||
_coreml_output_name = spec.description.output[0].name if spec.description.output else None
|
||||
return _coreml_encoder, _coreml_input_name, _coreml_output_name
|
||||
|
||||
|
||||
class AlignAtt:
|
||||
def __init__(
|
||||
self,
|
||||
cfg: AlignAttConfig,
|
||||
loaded_model=None,
|
||||
mlx_encoder=None,
|
||||
fw_encoder=None,
|
||||
) -> None:
|
||||
self.log_segments = 0
|
||||
|
||||
self.model = loaded_model
|
||||
self.mlx_encoder = mlx_encoder
|
||||
self.fw_encoder = fw_encoder
|
||||
if fw_encoder:
|
||||
self.fw_feature_extractor = FeatureExtractor(feature_size=self.model.dims.n_mels)
|
||||
self.coreml_encoder_tuple = None
|
||||
if USE_MLCORE:
|
||||
self.coreml_encoder_tuple = load_coreml_encoder()
|
||||
self.use_mlcore = self.coreml_encoder_tuple is not None
|
||||
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
logger.info(f"Model dimensions: {self.model.dims}")
|
||||
self.speaker = -1
|
||||
self.decode_options = DecodingOptions(
|
||||
language = cfg.language,
|
||||
without_timestamps = True,
|
||||
task=cfg.task
|
||||
)
|
||||
self.tokenizer_is_multilingual = cfg.tokenizer_is_multilingual
|
||||
self.create_tokenizer(cfg.language if cfg.language != "auto" else None)
|
||||
# self.create_tokenizer('en')
|
||||
self.detected_language = cfg.language if cfg.language != "auto" else None
|
||||
self.global_time_offset = 0.0
|
||||
self.reset_tokenizer_to_auto_next_call = False
|
||||
|
||||
self.max_text_len = self.model.dims.n_text_ctx
|
||||
self.num_decoder_layers = len(self.model.decoder.blocks)
|
||||
self.cfg = cfg
|
||||
self.l_hooks = []
|
||||
|
||||
# model to detect end-of-word boundary at the end of the segment
|
||||
self.CIFLinear, self.always_fire, self.never_fire = load_cif(cfg,
|
||||
n_audio_state=self.model.dims.n_audio_state,
|
||||
device=self.model.device)
|
||||
|
||||
# install hooks to access encoder-decoder attention
|
||||
self.dec_attns = []
|
||||
def layer_hook(module, net_input, net_output):
|
||||
# net_output[1]: B*num_head*token_len*audio_len
|
||||
t = F.softmax(net_output[1], dim=-1)
|
||||
self.dec_attns.append(t.squeeze(0))
|
||||
for b in self.model.decoder.blocks:
|
||||
hook = b.cross_attn.register_forward_hook(layer_hook)
|
||||
self.l_hooks.append(hook)
|
||||
|
||||
self.kv_cache = {}
|
||||
def kv_hook(module: torch.nn.Linear, _, net_output: torch.Tensor):
|
||||
if module.cache_id not in self.kv_cache or net_output.shape[1] > self.max_text_len:
|
||||
# save as-is, for the first token or cross attention
|
||||
self.kv_cache[module.cache_id] = net_output
|
||||
else:
|
||||
x = self.kv_cache[module.cache_id]
|
||||
self.kv_cache[module.cache_id] = torch.cat([x, net_output], dim=1).detach()
|
||||
return self.kv_cache[module.cache_id]
|
||||
|
||||
for i,b in enumerate(self.model.decoder.blocks):
|
||||
hooks = [
|
||||
b.attn.key.register_forward_hook(kv_hook),
|
||||
b.attn.value.register_forward_hook(kv_hook),
|
||||
b.cross_attn.key.register_forward_hook(kv_hook),
|
||||
b.cross_attn.value.register_forward_hook(kv_hook),
|
||||
]
|
||||
self.l_hooks.extend(hooks)
|
||||
|
||||
self.align_source = {}
|
||||
self.num_align_heads = 0
|
||||
for layer_rank, head_id in self.model.alignment_heads.indices().T:
|
||||
layer_rank = layer_rank.item()
|
||||
heads = self.align_source.get(layer_rank, [])
|
||||
heads.append((self.num_align_heads, head_id.item()))
|
||||
self.align_source[layer_rank] = heads
|
||||
self.num_align_heads += 1
|
||||
|
||||
|
||||
# tokens to be suppressed from decoding, to prevent hallucinations
|
||||
suppress_tokens = [
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
# self.tokenizer.eot
|
||||
self.tokenizer.no_timestamps, # added by DM
|
||||
] + list(self.tokenizer.all_language_tokens) # added by DM
|
||||
if self.tokenizer.no_speech is not None:
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
suppress_tokens = tuple(sorted(set(suppress_tokens)))
|
||||
logger.debug(f"Suppress tokens: {suppress_tokens}")
|
||||
sup_tokens = SuppressTokens(suppress_tokens)
|
||||
self.suppress_tokens = lambda logits: sup_tokens.apply(logits, None)
|
||||
# blank tokens are suppressed for new segments near line 334
|
||||
|
||||
# it's going to be regenerated after lang id
|
||||
self.segments = []
|
||||
self.init_tokens()
|
||||
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.first_timestamp = None
|
||||
|
||||
if self.cfg.max_context_tokens is None:
|
||||
self.max_context_tokens = self.max_text_len
|
||||
else:
|
||||
self.max_context_tokens = self.cfg.max_context_tokens
|
||||
self.init_context()
|
||||
|
||||
# decoder type: greedy or beam
|
||||
if cfg.decoder_type == "greedy":
|
||||
logger.info("Using greedy decoder")
|
||||
self.token_decoder = GreedyDecoder(0.0, self.tokenizer.eot)
|
||||
self.decoder_type = "greedy"
|
||||
|
||||
elif cfg.decoder_type == "beam":
|
||||
self.decoder_type = "beam"
|
||||
self.inference = BeamPyTorchInference(self.model, self.initial_token_length)
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
|
||||
self.token_decoder = BeamSearchDecoder(inference=self.inference, eot=self.tokenizer.eot, beam_size=cfg.beam_size)
|
||||
|
||||
# Tokens to carry over to next chunk for incomplete UTF-8 characters
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def remove_hooks(self):
|
||||
for hook in self.l_hooks:
|
||||
hook.remove()
|
||||
|
||||
def warmup(self, audio):
|
||||
try:
|
||||
self.insert_audio(audio)
|
||||
self.infer(is_last=True)
|
||||
self.refresh_segment(complete=True)
|
||||
logger.info("Model warmed up successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Model warmup failed: {e}")
|
||||
|
||||
def create_tokenizer(self, language=None):
|
||||
self.tokenizer = tokenizer.get_tokenizer(
|
||||
multilingual=self.tokenizer_is_multilingual,
|
||||
language=language,
|
||||
num_languages=self.model.num_languages,
|
||||
task=self.decode_options.task
|
||||
)
|
||||
|
||||
def init_context(self):
|
||||
kw = {'tokenizer': self.tokenizer,
|
||||
'device': self.model.device,
|
||||
'prefix_token_ids': [self.tokenizer.sot_prev]}
|
||||
self.context = TokenBuffer.empty(**kw)
|
||||
if self.cfg.static_init_prompt is not None:
|
||||
self.context = TokenBuffer.from_text(self.cfg.static_init_prompt, **kw)
|
||||
if self.cfg.init_prompt is not None:
|
||||
self.context.text += self.cfg.init_prompt
|
||||
|
||||
def init_tokens(self):
|
||||
logger.debug(f"init tokens, {len(self.segments)}")
|
||||
# init tokens (mandatory prompt)
|
||||
self.initial_tokens = torch.tensor(
|
||||
self.tokenizer.sot_sequence_including_notimestamps,
|
||||
dtype=torch.long,
|
||||
device=self.model.device).unsqueeze(0)
|
||||
self.initial_token_length = self.initial_tokens.shape[1]
|
||||
self.sot_index = self.tokenizer.sot_sequence.index(self.tokenizer.sot)
|
||||
# self.segments = []
|
||||
logger.debug(f"init tokens after, {len(self.segments)}")
|
||||
self.tokens = [self.initial_tokens]
|
||||
|
||||
def trim_context(self):
|
||||
logger.info("Trimming context")
|
||||
c = len(self.context.as_token_ids()) - len(self.context.prefix_token_ids)
|
||||
# logger.debug(f"c= {len(self.context.as_token_ids())}, {len(self.context.prefix_token_ids)}")
|
||||
logger.info(f"Context text: {self.context.as_text()}")
|
||||
# logger.debug(f"Context tensor: {self.context.as_tensor()}")
|
||||
l = sum(t.shape[1] for t in self.tokens) + c
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if self.cfg.static_init_prompt is None:
|
||||
after = 0
|
||||
else:
|
||||
after = len(self.cfg.static_init_prompt)
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
while c > self.max_context_tokens or l > self.max_text_len - 20:
|
||||
t = self.context.trim_words(after=after)
|
||||
l -= t
|
||||
c -= t
|
||||
logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
if t == 0:
|
||||
break
|
||||
# logger.debug(f"len {l}, c {c}, max_context_tokens {self.max_context_tokens}")
|
||||
logger.info(f"Context after trim: {self.context.text} (len: {l})")
|
||||
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor) -> torch.Tensor:
|
||||
if self.cfg.decoder_type == "greedy":
|
||||
logit = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
else:
|
||||
logger.debug(f"Logits shape: {tokens.shape}")
|
||||
logit = self.inference.logits(tokens, audio_features)
|
||||
return logit
|
||||
|
||||
|
||||
def refresh_segment(self, complete=False):
|
||||
|
||||
logger.debug("Refreshing segment:")
|
||||
self.init_tokens()
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.detected_language = None
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_context()
|
||||
logger.debug(f"Context: {self.context}")
|
||||
if not complete and len(self.segments) > 2:
|
||||
self.segments = self.segments[-2:]
|
||||
else:
|
||||
logger.debug("removing all segments.")
|
||||
self.segments = []
|
||||
self.log_segments += 1
|
||||
|
||||
self.pending_incomplete_tokens = []
|
||||
|
||||
def fire_at_boundary(self, chunked_encoder_feature: torch.Tensor):
|
||||
if self.always_fire: return True
|
||||
if self.never_fire: return False
|
||||
return fire_at_boundary(chunked_encoder_feature, self.CIFLinear)
|
||||
|
||||
|
||||
def _current_tokens(self):
|
||||
|
||||
toks = self.tokens
|
||||
# very first infer: duplicate start of seq to beam_size
|
||||
if toks[0].shape[0] == 1:
|
||||
toks[0] = toks[0].repeat_interleave(self.cfg.beam_size,dim=0)
|
||||
|
||||
if not self.context.is_empty():
|
||||
context_toks = self.context.as_tensor_beam(self.cfg.beam_size, device=self.model.device)
|
||||
toks = [context_toks] + toks
|
||||
|
||||
# make it one tensor
|
||||
if len(toks) > 1:
|
||||
current_tokens = torch.cat(toks, dim=1)
|
||||
else:
|
||||
current_tokens = toks[0]
|
||||
logger.debug("debug print current_tokens:")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
return current_tokens
|
||||
|
||||
|
||||
def debug_print_tokens(self, tokens):
|
||||
for i in range(self.cfg.beam_size):
|
||||
logger.debug(self.tokenizer.decode_with_timestamps(tokens[i].tolist()))
|
||||
|
||||
### audio buffer
|
||||
|
||||
def segments_len(self):
|
||||
segments_len = sum(s.shape[0] for s in self.segments) / 16000
|
||||
return segments_len
|
||||
|
||||
def _apply_minseglen(self):
|
||||
segments_len = self.segments_len()
|
||||
# wait for long enough audio to start
|
||||
if segments_len < self.cfg.audio_min_len:
|
||||
logger.debug("waiting for next segment")
|
||||
return False
|
||||
return True
|
||||
|
||||
def insert_audio(self, segment=None):
|
||||
if segment is not None:
|
||||
self.segments.append(segment)
|
||||
|
||||
removed_len = 0
|
||||
# if the buffered audio is longer than audio_max_len, drop segments from the front
|
||||
segments_len = self.segments_len()
|
||||
while len(self.segments) > 1 and segments_len > self.cfg.audio_max_len:
|
||||
removed_len = self.segments[0].shape[0] / 16000
|
||||
segments_len -= removed_len
|
||||
self.last_attend_frame -= int(TOKENS_PER_SECOND*removed_len)
|
||||
self.cumulative_time_offset += removed_len # Track cumulative time removed
|
||||
self.segments = self.segments[1:]
|
||||
logger.debug(f"remove segments: {len(self.segments)} {len(self.tokens)}, cumulative offset: {self.cumulative_time_offset:.2f}s")
|
||||
if len(self.tokens) > 1:
|
||||
self.context.append_token_ids(self.tokens[1][0,:].tolist())
|
||||
self.tokens = [self.initial_tokens] + self.tokens[2:]
|
||||
return removed_len
|
||||
|
||||
def _clean_cache(self):
|
||||
'''clean the cache that stores the attention matrices and kv_cache.
|
||||
It must be called every time after generation with the model.'''
|
||||
# cleaning cache
|
||||
self.dec_attns = []
|
||||
self.kv_cache = {}
|
||||
if self.decoder_type == "beam":
|
||||
self.inference.kv_cache = self.kv_cache
|
||||
self.token_decoder.reset()
|
||||
|
||||
@torch.no_grad()
|
||||
def lang_id(self, encoder_features):
|
||||
"""Language detection from encoder features.
|
||||
This code is trimmed and copy-pasted from whisper.decoding.detect_language.
|
||||
"""
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = encoder_features.shape[0]
|
||||
x = torch.tensor([[self.tokenizer.sot]] * n_audio).to(self.model.device) # [n_audio, 1]
|
||||
logits = self.model.logits(x, encoder_features)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(self.tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(self.tokenizer.all_language_tokens, self.tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
single = encoder_features.ndim == 2
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
self._clean_cache()
|
||||
return language_tokens, language_probs
|
||||
|
||||
### transcription / translation
|
||||
|
||||
@torch.no_grad()
|
||||
def infer(self, is_last=False):
|
||||
new_segment = True
|
||||
if len(self.segments) == 0:
|
||||
logger.debug("No segments, nothing to do")
|
||||
return []
|
||||
if not self._apply_minseglen():
|
||||
logger.debug(f"applied minseglen {self.cfg.audio_min_len} > {self.segments_len()}.")
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
return []
|
||||
|
||||
# input_segments is concatenation of audio, it's one array
|
||||
if len(self.segments) > 1:
|
||||
input_segments = torch.cat(self.segments, dim=0)
|
||||
else:
|
||||
input_segments = self.segments[0]
|
||||
|
||||
beg_encode = time()
|
||||
if self.use_mlcore:
|
||||
coreml_encoder, coreml_input_name, coreml_output_name = self.coreml_encoder_tuple
|
||||
mel_padded = log_mel_spectrogram(
|
||||
input_segments,
|
||||
n_mels=self.model.dims.n_mels,
|
||||
padding=N_SAMPLES,
|
||||
device="cpu",
|
||||
).unsqueeze(0)
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2]) / 2)
|
||||
mel_np = np.ascontiguousarray(mel.numpy())
|
||||
ml_inputs = {coreml_input_name or "mel": mel_np}
|
||||
coreml_outputs = coreml_encoder.predict(ml_inputs)
|
||||
if coreml_output_name and coreml_output_name in coreml_outputs:
|
||||
encoder_feature_np = coreml_outputs[coreml_output_name]
|
||||
else:
|
||||
encoder_feature_np = next(iter(coreml_outputs.values()))
|
||||
encoder_feature = torch.as_tensor(
|
||||
np.array(encoder_feature_np),
|
||||
device=self.device,
|
||||
)
|
||||
if self.mlx_encoder:
|
||||
mlx_mel_padded = mlx_log_mel_spectrogram(audio=input_segments.detach(), n_mels=self.model.dims.n_mels, padding=N_SAMPLES)
|
||||
mlx_mel = mlx_pad_or_trim(mlx_mel_padded, N_FRAMES, axis=-2)
|
||||
mlx_encoder_feature = self.mlx_encoder.encoder(mlx_mel[None])
|
||||
encoder_feature = torch.as_tensor(mlx_encoder_feature)
|
||||
content_mel_len = int((mlx_mel_padded.shape[0] - mlx_mel.shape[0])/2)
|
||||
elif self.fw_encoder:
|
||||
audio_length_seconds = len(input_segments) / 16000
|
||||
content_mel_len = int(audio_length_seconds * 100)//2
|
||||
mel_padded_2 = self.fw_feature_extractor(waveform=input_segments.numpy(), padding=N_SAMPLES)[None, :]
|
||||
mel = fw_pad_or_trim(mel_padded_2, N_FRAMES, axis=-1)
|
||||
encoder_feature_ctranslate = self.fw_encoder.encode(mel)
|
||||
if self.device == 'cpu': # on GPU, passing a StorageView directly to torch.as_tensor can fail; wrapping it in a numpy array first works around this
|
||||
encoder_feature_ctranslate = np.array(encoder_feature_ctranslate)
|
||||
try:
|
||||
encoder_feature = torch.as_tensor(encoder_feature_ctranslate, device=self.device)
|
||||
except TypeError: # Normally the cpu condition should prevent having exceptions, but just in case:
|
||||
encoder_feature = torch.as_tensor(np.array(encoder_feature_ctranslate), device=self.device)
|
||||
else:
|
||||
# mel + padding to 30s
|
||||
mel_padded = log_mel_spectrogram(input_segments, n_mels=self.model.dims.n_mels, padding=N_SAMPLES,
|
||||
device=self.device).unsqueeze(0)
|
||||
# trim to 3000
|
||||
mel = pad_or_trim(mel_padded, N_FRAMES)
|
||||
# length of the actual audio content, in encoder frames
|
||||
content_mel_len = int((mel_padded.shape[2] - mel.shape[2])/2)
|
||||
encoder_feature = self.model.encoder(mel)
|
||||
end_encode = time()
|
||||
# print('Encoder duration:', end_encode-beg_encode)
|
||||
|
||||
if self.cfg.language == "auto" and self.detected_language is None and self.first_timestamp:
|
||||
seconds_since_start = self.segments_len() - self.first_timestamp
|
||||
if seconds_since_start >= 2.0:
|
||||
language_tokens, language_probs = self.lang_id(encoder_feature)
|
||||
top_lan, p = max(language_probs[0].items(), key=lambda x: x[1])
|
||||
print(f"Detected language: {top_lan} with p={p:.4f}")
|
||||
self.create_tokenizer(top_lan)
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
self.cumulative_time_offset = 0.0
|
||||
self.init_tokens()
|
||||
self.init_context()
|
||||
self.detected_language = top_lan
|
||||
logger.info(f"Tokenizer language: {self.tokenizer.language}, {self.tokenizer.sot_sequence_including_notimestamps}")
|
||||
|
||||
self.trim_context()
|
||||
current_tokens = self._current_tokens()
|
||||
|
||||
fire_detected = self.fire_at_boundary(encoder_feature[:, :content_mel_len, :])
|
||||
|
||||
|
||||
sum_logprobs = torch.zeros(self.cfg.beam_size, device=self.device)
|
||||
completed = False
|
||||
# punctuation_stop = False
|
||||
|
||||
attn_of_alignment_heads = None
|
||||
most_attended_frame = None
|
||||
|
||||
token_len_before_decoding = current_tokens.shape[1]
|
||||
|
||||
l_absolute_timestamps = []
|
||||
|
||||
while not completed and current_tokens.shape[1] < self.max_text_len: # bos is 3 tokens
|
||||
|
||||
if new_segment:
|
||||
tokens_for_logits = current_tokens
|
||||
else:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens_for_logits = current_tokens[:,-1:]
|
||||
|
||||
logits = self.logits(tokens_for_logits, encoder_feature) # B, len(tokens), token dict size
|
||||
|
||||
if new_segment and self.tokenizer.no_speech is not None:
|
||||
probs_at_sot = logits[:, self.sot_index, :].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
if no_speech_probs[0] > self.cfg.nonspeech_prob:
|
||||
logger.info("no speech, stop")
|
||||
break
|
||||
|
||||
logits = logits[:, -1, :] # logits for the last token
|
||||
|
||||
# suppress blank tokens only at the beginning of the segment
|
||||
if new_segment:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
new_segment = False
|
||||
self.suppress_tokens(logits)
|
||||
current_tokens, completed = self.token_decoder.update(current_tokens, logits, sum_logprobs)
|
||||
|
||||
logger.debug(f"Decoding completed: {completed}, sum_logprobs: {sum_logprobs.tolist()}, tokens: ")
|
||||
self.debug_print_tokens(current_tokens)
|
||||
|
||||
attn_of_alignment_heads = [[] for _ in range(self.num_align_heads)]
|
||||
for i, attn_mat in enumerate(self.dec_attns):
|
||||
layer_rank = int(i % len(self.model.decoder.blocks))
|
||||
align_heads_in_layer = self.align_source.get(layer_rank, [])
|
||||
if len(align_heads_in_layer) == 0:
|
||||
continue
|
||||
for align_head_rank, head_id in align_heads_in_layer:
|
||||
if self.cfg.beam_size == 1:
|
||||
a = attn_mat[head_id, :, :]
|
||||
a = a.unsqueeze(0)
|
||||
else:
|
||||
a = attn_mat[:, head_id, :, :]
|
||||
attn_of_alignment_heads[align_head_rank].append(a)
|
||||
tmp = []
|
||||
for mat in attn_of_alignment_heads:
|
||||
t = torch.cat(mat, dim=1)
|
||||
tmp.append(t)
|
||||
attn_of_alignment_heads = torch.stack(tmp, dim=1)
|
||||
std, mean = torch.std_mean(attn_of_alignment_heads, dim=-2, keepdim=True, unbiased=False)
|
||||
attn_of_alignment_heads = (attn_of_alignment_heads - mean) / std
|
||||
attn_of_alignment_heads = median_filter(attn_of_alignment_heads, 7) # from whisper.timing
|
||||
attn_of_alignment_heads = attn_of_alignment_heads.mean(dim=1)
|
||||
attn_of_alignment_heads = attn_of_alignment_heads[:,:, :content_mel_len]
|
||||
|
||||
# for each beam, the most attended frame is:
|
||||
most_attended_frames = torch.argmax(attn_of_alignment_heads[:,-1,:], dim=-1)
|
||||
|
||||
# Calculate absolute timestamps accounting for cumulative offset
|
||||
absolute_timestamps = [(frame * 0.02 + self.cumulative_time_offset) for frame in most_attended_frames.tolist()]
|
||||
|
||||
logger.debug(str(most_attended_frames.tolist()) + " most att frames")
|
||||
logger.debug(f"Absolute timestamps: {absolute_timestamps} (offset: {self.cumulative_time_offset:.2f}s)")
|
||||
|
||||
most_attended_frame = most_attended_frames[0].item()
|
||||
l_absolute_timestamps.append(absolute_timestamps[0])
|
||||
|
||||
logger.debug("current tokens" + str(current_tokens.shape))
|
||||
if completed:
|
||||
# stripping the last token, the eot
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# for some rare cases where the attention fails
|
||||
if not is_last and self.last_attend_frame - most_attended_frame > self.cfg.rewind_threshold:
|
||||
# TODO: check this
|
||||
if current_tokens.shape[1] > 1 and current_tokens[0, -2] >= DEC_PAD:
|
||||
logger.debug("ommit rewinding from special tokens")
|
||||
self.last_attend_frame = most_attended_frame
|
||||
else:
|
||||
logger.debug(
|
||||
f"[rewind detected] current attention pos: {most_attended_frame}, "
|
||||
f"last attention pos: {self.last_attend_frame}; omit this segment")
|
||||
self.last_attend_frame = -self.cfg.rewind_threshold
|
||||
current_tokens = torch.cat(self.tokens, dim=1) if len(self.tokens) > 0 else self.tokens[0]
|
||||
break
|
||||
else:
|
||||
self.last_attend_frame = most_attended_frame
|
||||
|
||||
if content_mel_len - most_attended_frame <= (4 if is_last else self.cfg.frame_threshold):
|
||||
logger.debug(f"attention reaches the end: {most_attended_frame}/{content_mel_len}")
|
||||
# stripping the last token, the one that is attended too close to the end
|
||||
current_tokens = current_tokens[:, :-1]
|
||||
break
|
||||
|
||||
# debug print
|
||||
for i in range(self.cfg.beam_size):
|
||||
logger.debug("attn: {}, current pos: {}, current token: {}({})".format(
|
||||
attn_of_alignment_heads.shape if attn_of_alignment_heads is not None else None,
|
||||
most_attended_frames[i],
|
||||
current_tokens[i, -1].item(),
|
||||
self.tokenizer.decode([current_tokens[i, -1].item()])
|
||||
))
|
||||
|
||||
tokens_to_split = current_tokens[0, token_len_before_decoding:]
|
||||
|
||||
# Prepend pending tokens from previous chunk if any
|
||||
if self.pending_incomplete_tokens:
|
||||
logger.debug(f"[UTF-8 Fix] Prepending {len(self.pending_incomplete_tokens)} pending tokens: {self.pending_incomplete_tokens}")
|
||||
pending_tensor = torch.tensor(self.pending_incomplete_tokens, dtype=torch.long, device=self.device)
|
||||
tokens_to_split = torch.cat([pending_tensor, tokens_to_split])
|
||||
|
||||
if fire_detected or is_last: #or punctuation_stop:
|
||||
new_hypothesis = tokens_to_split.flatten().tolist()
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(new_hypothesis)
|
||||
else:
|
||||
# going to truncate the tokens after the last space
|
||||
split_words, split_tokens = self.tokenizer.split_to_word_tokens(tokens_to_split.tolist())
|
||||
if len(split_words) > 1:
|
||||
new_hypothesis = [i for sublist in split_tokens[:-1] for i in sublist]
|
||||
else:
|
||||
new_hypothesis = []
|
||||
|
||||
|
||||
logger.debug(f"new_hypothesis: {new_hypothesis}")
|
||||
new_tokens = torch.tensor([new_hypothesis], dtype=torch.long).repeat_interleave(self.cfg.beam_size, dim=0).to(
|
||||
device=self.device,
|
||||
)
|
||||
self.tokens.append(new_tokens)
|
||||
|
||||
logger.info(f"Output: {self.tokenizer.decode(new_hypothesis)}")
|
||||
|
||||
self._clean_cache()
|
||||
|
||||
if len(l_absolute_timestamps) >=2 and self.first_timestamp is None:
|
||||
self.first_timestamp = l_absolute_timestamps[0]
|
||||
|
||||
|
||||
timestamped_words = []
|
||||
timestamp_idx = 0
|
||||
replacement_char = "\ufffd"
|
||||
for word, word_tokens in zip(split_words, split_tokens):
|
||||
# Skip words containing incomplete UTF-8 from client output
|
||||
if replacement_char in word:
|
||||
logger.warning(f"[UTF-8 Filter] Skipping incomplete word from client output: {repr(word)}")
|
||||
timestamp_idx += len(word_tokens)
|
||||
continue
|
||||
|
||||
try:
current_timestamp = l_absolute_timestamps[timestamp_idx]
except IndexError:
# fall back to the last available timestamp instead of leaving current_timestamp undefined
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
timestamp_idx += len(word_tokens)
|
||||
|
||||
timestamp_entry = ASRToken(
|
||||
start=round(current_timestamp, 2),
|
||||
end=round(current_timestamp + 0.1, 2),
|
||||
text= word,
|
||||
speaker=self.speaker,
|
||||
detected_language=self.detected_language
|
||||
).with_offset(
|
||||
self.global_time_offset
|
||||
)
|
||||
timestamped_words.append(timestamp_entry)
|
||||
|
||||
# Hold incomplete tokens for next chunk
|
||||
self.pending_incomplete_tokens = []
|
||||
if split_words and replacement_char in split_words[-1]:
|
||||
self.pending_incomplete_tokens = split_tokens[-1]
|
||||
logger.warning(f"[UTF-8 Fix] Holding {len(self.pending_incomplete_tokens)} incomplete tokens for next chunk: {self.pending_incomplete_tokens}")
|
||||
|
||||
return timestamped_words
|
||||
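A rough sketch of driving the streaming loop above. The load_model call is an assumption (the vendored whisperlivekit.whisper package mirroring openai-whisper); insert_audio expects 16 kHz mono float tensors and infer returns ASRToken objects with absolute timestamps. Real callers feed microphone chunks instead of the silent buffer used here.

import torch
from whisperlivekit.simul_whisper.config import AlignAttConfig
from whisperlivekit.simul_whisper.simul_whisper import AlignAtt
from whisperlivekit.whisper import load_model   # assumed loader, same API as openai-whisper

cfg = AlignAttConfig(language="en", decoder_type="greedy", beam_size=1, frame_threshold=25)
streamer = AlignAtt(cfg, loaded_model=load_model("base.en"))

audio = torch.zeros(16000 * 5)          # 5 s stand-in for a live 16 kHz mono stream
chunk = 8000                            # 0.5 s per step
for start in range(0, audio.shape[0], chunk):
    streamer.insert_audio(audio[start:start + chunk])
    is_last = start + chunk >= audio.shape[0]
    for token in streamer.infer(is_last=is_last):
        print(f"{token.start:.2f}-{token.end:.2f} {token.text}")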
93
whisperlivekit/simul_whisper/token_buffer.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import sys
|
||||
class TokenBuffer:
|
||||
|
||||
def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=[]):
|
||||
self.text = text
|
||||
self.prefix_token_ids = prefix_token_ids
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_token_ids(self, tokenizer=None):
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("Tokenizer is not set.")
|
||||
return self.prefix_token_ids + tokenizer.encode(self.text)
|
||||
|
||||
def as_tensor(self, device=None):
|
||||
if device is None:
|
||||
device = self.device
|
||||
if device is None:
|
||||
raise ValueError("Device is not set.")
|
||||
tok_ids = self.as_token_ids()
|
||||
return torch.tensor(tok_ids,
|
||||
dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
def as_tensor_beam(self, beam, device=None):
|
||||
t = self.as_tensor(device=device)
|
||||
return t.repeat_interleave(beam, dim=0)
|
||||
|
||||
|
||||
def as_text(self):
|
||||
return self.text
|
||||
|
||||
@staticmethod
|
||||
def empty(*a, **kw):
|
||||
return TokenBuffer(*a,**kw)
|
||||
|
||||
@staticmethod
|
||||
def from_text(text, *a, **kw):
|
||||
return TokenBuffer(*a, text=text, **kw)
|
||||
|
||||
def is_empty(self):
|
||||
return self.text is None or self.text == ""
|
||||
|
||||
def trim_words(self, num=1, after=0):
|
||||
'''
|
||||
num: how many words to trim from the beginning
|
||||
after: how many characters to skip (length of the static prompt)
|
||||
'''
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
|
||||
ids = tokenizer.encode(self.text[after:])
|
||||
words, wids = self.tokenizer.split_to_word_tokens(ids)
|
||||
# print(words, file=sys.stderr)
|
||||
# print(wids, file=sys.stderr)
|
||||
if not words:
|
||||
return 0
|
||||
self.text = self.text[:after] + "".join(words[num:])
|
||||
return sum(len(wi) for wi in wids[:num])
|
||||
|
||||
def append_token_ids(self, token_ids):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
|
||||
all_tokens = self.pending_token_ids + token_ids
|
||||
|
||||
decoded = tokenizer.decode(all_tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
if replacement_char in decoded:
|
||||
if len(all_tokens) > 1:
|
||||
decoded_partial = tokenizer.decode(all_tokens[:-1])
|
||||
|
||||
if replacement_char not in decoded_partial:
|
||||
self.text += decoded_partial
|
||||
self.pending_token_ids = [all_tokens[-1]]
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.pending_token_ids = all_tokens
|
||||
else:
|
||||
self.text += decoded
|
||||
self.pending_token_ids = []
|
||||
|
||||
def as_split_word_tokens(self):
|
||||
tokenizer = self.tokenizer
|
||||
assert tokenizer is not None, "Tokenizer is not set."
|
||||
ids = tokenizer.encode(self.text)
|
||||
return tokenizer.split_to_word_tokens(ids)
|
||||
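A sketch of the context buffer in isolation. The tokenizer construction mirrors what simul_whisper.py does and is an assumption here; the buffer keeps committed text behind a sot_prev prefix, appends decoded ids while holding back incomplete UTF-8 sequences, and trims whole words when the prompt grows too long.

import torch
from whisperlivekit.whisper import tokenizer
from whisperlivekit.simul_whisper.token_buffer import TokenBuffer

tok = tokenizer.get_tokenizer(multilingual=True, language="en", task="transcribe")
buf = TokenBuffer.from_text("Hello world.",
                            tokenizer=tok,
                            device=torch.device("cpu"),
                            prefix_token_ids=[tok.sot_prev])

print(buf.as_token_ids())        # [sot_prev, ...ids for "Hello world."]
print(buf.as_tensor().shape)     # torch.Size([1, n])

buf.append_token_ids(tok.encode(" How are you?"))   # safe even if ids split a multi-byte character
removed = buf.trim_words(num=1)                     # drop the first word, returns token count removed
print(removed, buf.as_text())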
203
whisperlivekit/timed_objects.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any, List
|
||||
from datetime import timedelta
|
||||
|
||||
PUNCTUATION_MARKS = {'.', '!', '?', '。', '!', '?'}
|
||||
|
||||
def format_time(seconds: float) -> str:
|
||||
"""Format seconds as HH:MM:SS."""
|
||||
return str(timedelta(seconds=int(seconds)))
|
||||
|
||||
@dataclass
|
||||
class Timed:
|
||||
start: Optional[float] = 0
|
||||
end: Optional[float] = 0
|
||||
|
||||
@dataclass
|
||||
class TimedText(Timed):
|
||||
text: Optional[str] = ''
|
||||
speaker: Optional[int] = -1
|
||||
detected_language: Optional[str] = None
|
||||
|
||||
def is_punctuation(self):
|
||||
return self.text.strip() in PUNCTUATION_MARKS
|
||||
|
||||
def overlaps_with(self, other: 'TimedText') -> bool:
|
||||
return not (self.end <= other.start or other.end <= self.start)
|
||||
|
||||
def is_within(self, other: 'TimedText') -> bool:
|
||||
return other.contains_timespan(self)
|
||||
|
||||
def duration(self) -> float:
|
||||
return self.end - self.start
|
||||
|
||||
def contains_time(self, time: float) -> bool:
|
||||
return self.start <= time <= self.end
|
||||
|
||||
def contains_timespan(self, other: 'TimedText') -> bool:
|
||||
return self.start <= other.start and self.end >= other.end
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.text)
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ASRToken(TimedText):
|
||||
|
||||
corrected_speaker: Optional[int] = -1
|
||||
validated_speaker: bool = False
|
||||
validated_text: bool = False
|
||||
validated_language: bool = False
|
||||
|
||||
def with_offset(self, offset: float) -> "ASRToken":
|
||||
"""Return a new token with the time offset added."""
|
||||
return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, detected_language=self.detected_language)
|
||||
|
||||
@dataclass
|
||||
class Sentence(TimedText):
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Transcript(TimedText):
|
||||
"""
|
||||
represents a concatenation of several ASRToken
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_tokens(
|
||||
cls,
|
||||
tokens: List[ASRToken],
|
||||
sep: Optional[str] = None,
|
||||
offset: float = 0
|
||||
) -> "Transcript":
|
||||
sep = sep if sep is not None else ' '
|
||||
text = sep.join(token.text for token in tokens)
|
||||
if tokens:
|
||||
start = offset + tokens[0].start
|
||||
end = offset + tokens[-1].end
|
||||
else:
|
||||
start = None
|
||||
end = None
|
||||
return cls(start, end, text)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpeakerSegment(Timed):
|
||||
"""Represents a segment of audio attributed to a specific speaker.
|
||||
No text nor probability is associated with this segment.
|
||||
"""
|
||||
speaker: Optional[int] = -1
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Translation(TimedText):
|
||||
pass
|
||||
|
||||
def approximate_cut_at(self, cut_time):
|
||||
"""
|
||||
Each word in text is considered to be of duration (end-start)/len(words in text)
|
||||
"""
|
||||
if not self.text or not self.contains_time(cut_time):
|
||||
return self, None
|
||||
|
||||
words = self.text.split()
|
||||
num_words = len(words)
|
||||
if num_words == 0:
|
||||
return self, None
|
||||
|
||||
duration_per_word = self.duration() / num_words
|
||||
|
||||
cut_word_index = int((cut_time - self.start) / duration_per_word)
|
||||
|
||||
if cut_word_index >= num_words:
|
||||
cut_word_index = num_words -1
|
||||
|
||||
text0 = " ".join(words[:cut_word_index])
|
||||
text1 = " ".join(words[cut_word_index:])
|
||||
|
||||
segment0 = Translation(start=self.start, end=cut_time, text=text0)
|
||||
segment1 = Translation(start=cut_time, end=self.end, text=text1)
|
||||
|
||||
return segment0, segment1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Silence():
|
||||
duration: Optional[float] = None
|
||||
is_starting: bool = False
|
||||
has_ended: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Line(TimedText):
|
||||
translation: str = ''
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
'speaker': int(self.speaker) if self.speaker != -1 else 1,
|
||||
'text': self.text,
|
||||
'start': format_time(self.start),
|
||||
'end': format_time(self.end),
|
||||
}
|
||||
if self.translation:
|
||||
_dict['translation'] = self.translation
|
||||
if self.detected_language:
|
||||
_dict['detected_language'] = self.detected_language
|
||||
return _dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontData():
|
||||
status: str = ''
|
||||
error: str = ''
|
||||
lines: list[Line] = field(default_factory=list)
|
||||
buffer_transcription: str = ''
|
||||
buffer_diarization: str = ''
|
||||
buffer_translation: str = ''
|
||||
remaining_time_transcription: float = 0.
|
||||
remaining_time_diarization: float = 0.
|
||||
|
||||
def to_dict(self):
|
||||
_dict = {
|
||||
'status': self.status,
|
||||
'lines': [line.to_dict() for line in self.lines if (line.text or line.speaker == -2)],
|
||||
'buffer_transcription': self.buffer_transcription,
|
||||
'buffer_diarization': self.buffer_diarization,
|
||||
'buffer_translation': self.buffer_translation,
|
||||
'remaining_time_transcription': self.remaining_time_transcription,
|
||||
'remaining_time_diarization': self.remaining_time_diarization,
|
||||
}
|
||||
if self.error:
|
||||
_dict['error'] = self.error
|
||||
return _dict
|
||||
|
||||
@dataclass
|
||||
class ChangeSpeaker:
|
||||
speaker: int
|
||||
start: int
|
||||
|
||||
@dataclass
|
||||
class State():
|
||||
tokens: list = field(default_factory=list)
|
||||
last_validated_token: int = 0
|
||||
last_speaker: int = 1
|
||||
last_punctuation_index: Optional[int] = None
|
||||
translation_validated_segments: list = field(default_factory=list)
|
||||
buffer_translation: Transcript = field(default_factory=Transcript)
buffer_transcription: Transcript = field(default_factory=Transcript)
|
||||
diarization_segments: list = field(default_factory=list)
|
||||
end_buffer: float = 0.0
|
||||
end_attributed_speaker: float = 0.0
|
||||
remaining_time_transcription: float = 0.0
|
||||
remaining_time_diarization: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateLight():
|
||||
new_tokens: list = field(default_factory=list)
|
||||
new_translation: list = field(default_factory=list)
|
||||
new_diarization: list = field(default_factory=list)
|
||||
new_tokens_buffer: list = field(default_factory=list) #only when local agreement
|
||||
new_tokens_index = 0
|
||||
new_translation_index = 0
|
||||
new_diarization_index = 0
|
||||
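A small sketch of how these dataclasses compose: ASRTokens carry per-word timings, Transcript.from_tokens concatenates them, and Line.to_dict produces the payload the front end renders. Values are illustrative.

from whisperlivekit.timed_objects import ASRToken, Transcript, Line

tokens = [
    ASRToken(start=0.00, end=0.32, text="Hello", speaker=1),
    ASRToken(start=0.32, end=0.60, text="world.", speaker=1),
]
tokens = [t.with_offset(5.0) for t in tokens]        # re-anchor to absolute stream time

transcript = Transcript.from_tokens(tokens, sep=" ")
print(transcript.start, transcript.end, transcript.text)   # 5.0 5.6 Hello world.

line = Line(start=5.0, end=5.6, text=transcript.text, speaker=1, detected_language="en")
print(line.to_dict())
# {'speaker': 1, 'text': 'Hello world.', 'start': '0:00:05', 'end': '0:00:05', 'detected_language': 'en'}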
60
whisperlivekit/trail_repetition.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Sequence, Callable, Any, Optional, Dict
|
||||
|
||||
def _detect_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x, # extract comparable value
|
||||
min_block: int = 1, # set to 2 to ignore 1-token loops like "."
|
||||
max_tail: int = 300, # search window from the end for speed
|
||||
prefer: str = "longest", # "longest" coverage or "smallest" block
|
||||
) -> Optional[Dict]:
|
||||
vals = [key(x) for x in seq][-max_tail:]
|
||||
n = len(vals)
|
||||
best = None
|
||||
|
||||
# try every possible block length
|
||||
for b in range(min_block, n // 2 + 1):
|
||||
block = vals[-b:]
|
||||
# count how many times this block repeats contiguously at the very end
|
||||
count, i = 0, n
|
||||
while i - b >= 0 and vals[i - b:i] == block:
|
||||
count += 1
|
||||
i -= b
|
||||
|
||||
if count >= 2:
|
||||
cand = {
|
||||
"block_size": b,
|
||||
"count": count,
|
||||
"start_index": len(seq) - count * b, # in original seq
|
||||
"end_index": len(seq),
|
||||
}
|
||||
if (best is None or
|
||||
(prefer == "longest" and count * b > best["count"] * best["block_size"]) or
|
||||
(prefer == "smallest" and b < best["block_size"])):
|
||||
best = cand
|
||||
return best
|
||||
|
||||
def trim_tail_repetition(
|
||||
seq: Sequence[Any],
|
||||
key: Callable[[Any], Any] = lambda x: x,
|
||||
min_block: int = 1,
|
||||
max_tail: int = 300,
|
||||
prefer: str = "longest",
|
||||
keep: int = 1, # how many copies of the repeating block to keep at the end (0 or 1 are common)
|
||||
):
|
||||
"""
|
||||
Returns a new sequence with repeated tail trimmed.
|
||||
keep=1 -> keep a single copy of the repeated block.
|
||||
keep=0 -> remove all copies of the repeated block.
|
||||
"""
|
||||
rep = _detect_tail_repetition(seq, key, min_block, max_tail, prefer)
|
||||
if not rep:
|
||||
return seq, False # nothing to trim
|
||||
|
||||
b, c = rep["block_size"], rep["count"]
|
||||
if keep < 0:
|
||||
keep = 0
|
||||
if keep >= c:
|
||||
return seq, False # nothing to trim (already <= keep copies)
|
||||
# new length = total - (copies_to_remove * block_size)
|
||||
new_len = len(seq) - (c - keep) * b
|
||||
return seq[:new_len], True
|
||||
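A quick sketch of trimming a hallucinated tail loop. With keep=1 a single copy of the repeating block survives; the key= extractor can pull the comparable text out of richer objects such as ASRToken.

from whisperlivekit.trail_repetition import trim_tail_repetition

words = ["the", "cat", "sat", "on", "the", "mat",
         "on", "the", "mat", "on", "the", "mat"]
trimmed, changed = trim_tail_repetition(words, min_block=2, keep=1)
print(changed)    # True
print(trimmed)    # ['the', 'cat', 'sat', 'on', 'the', 'mat']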
0
whisperlivekit/vad_models/__init__.py
Normal file
BIN
whisperlivekit/vad_models/silero_vad.jit
Normal file
BIN
whisperlivekit/vad_models/silero_vad.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_16k_op15.onnx
Normal file
BIN
whisperlivekit/vad_models/silero_vad_half.onnx
Normal file
51
whisperlivekit/warmup.py
Normal file
@@ -0,0 +1,51 @@
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def load_file(warmup_file=None, timeout=5):
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.request
|
||||
import librosa
|
||||
|
||||
if warmup_file == "":
|
||||
logger.info(f"Skipping warmup.")
|
||||
return None
|
||||
|
||||
# Download JFK sample if not already present
|
||||
if warmup_file is None:
|
||||
jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
|
||||
temp_dir = tempfile.gettempdir()
|
||||
warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
|
||||
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
try:
|
||||
logger.debug(f"Downloading warmup file from {jfk_url}")
|
||||
with urllib.request.urlopen(jfk_url, timeout=timeout) as r, open(warmup_file, "wb") as f:
|
||||
f.write(r.read())
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup file download failed: {e}.")
|
||||
return None
|
||||
|
||||
# Validate file and load
|
||||
if not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
|
||||
logger.warning(f"Warmup file {warmup_file} is invalid or missing.")
|
||||
return None
|
||||
|
||||
try:
|
||||
audio, _ = librosa.load(warmup_file, sr=16000)
|
||||
return audio
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load warmup file: {e}")
|
||||
return None
|
||||
|
||||
def warmup_asr(asr, warmup_file=None, timeout=5):
|
||||
"""
|
||||
Warmup the ASR model by transcribing a short audio file.
|
||||
"""
|
||||
audio = load_file(warmup_file=warmup_file, timeout=timeout)
|
||||
if audio is None:
|
||||
logger.warning("Warmup file unavailable. Skipping ASR warmup.")
|
||||
return
|
||||
asr.transcribe(audio)
|
||||
logger.info("ASR model is warmed up.")
|
||||
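A tiny sketch of the warmup hook; the asr argument only needs a transcribe(audio) method, and the first call downloads the JFK sample (network access and librosa required). Passing an empty string skips warmup entirely, as load_file handles above.

from whisperlivekit.warmup import warmup_asr

class DummyASR:
    def transcribe(self, audio):
        # a real backend would run the model; here we just report the sample length
        print(f"warmup audio: {len(audio) / 16000:.1f} s")

warmup_asr(DummyASR())                   # downloads/loads the JFK sample, then transcribes it
warmup_asr(DummyASR(), warmup_file="")   # empty string: skip warmup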
0
whisperlivekit/web/__init__.py
Normal file
630
whisperlivekit/web/live_transcription.css
Normal file
@@ -0,0 +1,630 @@
|
||||
:root {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root:not([data-theme="light"]) {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
}
|
||||
|
||||
:root[data-theme="dark"] {
|
||||
--bg: #0b0b0b;
|
||||
--text: #e6e6e6;
|
||||
--muted: #9aa0a6;
|
||||
--border: #333333;
|
||||
--chip-bg: rgba(255, 255, 255, 0.08);
|
||||
--chip-text: #e6e6e6;
|
||||
--spinner-border: #555555;
|
||||
--spinner-top: #dddddd;
|
||||
--silence-bg: #1a1a1a;
|
||||
--loading-bg: rgba(255, 77, 77, 0.12);
|
||||
--button-bg: #111111;
|
||||
--button-border: #333333;
|
||||
--wave-stroke: #e6e6e6;
|
||||
--label-dia-text: #b3b3b3;
|
||||
--label-trans-text: #ffffff;
|
||||
}
|
||||
|
||||
:root[data-theme="light"] {
|
||||
--bg: #ffffff;
|
||||
--text: #111111;
|
||||
--muted: #666666;
|
||||
--border: #e5e5e5;
|
||||
--chip-bg: rgba(0, 0, 0, 0.04);
|
||||
--chip-text: #000000;
|
||||
--spinner-border: #8d8d8d5c;
|
||||
--spinner-top: #b0b0b0;
|
||||
--silence-bg: #f3f3f3;
|
||||
--loading-bg: rgba(255, 77, 77, 0.06);
|
||||
--button-bg: #ffffff;
|
||||
--button-border: #e9e9e9;
|
||||
--wave-stroke: #000000;
|
||||
--label-dia-text: #868686;
|
||||
--label-trans-text: #111111;
|
||||
}
|
||||
|
||||
html.is-extension
|
||||
{
|
||||
width: 350px;
|
||||
height: 500px;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-sans-serif, system-ui, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol', 'Noto Color Emoji';
|
||||
margin: 0;
|
||||
text-align: center;
|
||||
background-color: var(--bg);
|
||||
color: var(--text);
|
||||
height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
/* Record button */
|
||||
#recordButton {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
cursor: pointer;
|
||||
transition: all 0.3s ease;
|
||||
border: 1px solid var(--button-border);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
#recordButton.recording {
|
||||
width: 180px;
|
||||
border-radius: 40px;
|
||||
justify-content: flex-start;
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
#recordButton:active {
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.shape-container {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.shape {
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
background-color: rgb(209, 61, 53);
|
||||
border-radius: 50%;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
#recordButton:disabled .shape {
|
||||
background-color: #6e6d6d;
|
||||
}
|
||||
|
||||
#recordButton.recording .shape {
|
||||
border-radius: 5px;
|
||||
width: 25px;
|
||||
height: 25px;
|
||||
}
|
||||
|
||||
/* Recording elements */
|
||||
.recording-info {
|
||||
display: none;
|
||||
align-items: center;
|
||||
margin-left: 15px;
|
||||
flex-grow: 1;
|
||||
}
|
||||
|
||||
#recordButton.recording .recording-info {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.wave-container {
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
#waveCanvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
.timer {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--text);
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
#status {
|
||||
margin-top: 15px;
|
||||
font-size: 16px;
|
||||
color: var(--text);
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.header-container {
|
||||
position: sticky;
|
||||
top: 0;
|
||||
background-color: var(--bg);
|
||||
z-index: 100;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
/* Settings */
|
||||
.settings-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
position: relative;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.settings-toggle {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: none;
|
||||
border-radius: 50%;
|
||||
background-color: var(--button-bg);
|
||||
border: 1px solid var(--button-border);
|
||||
cursor: pointer;
|
||||
display: none;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle.active {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.settings-toggle img {
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 10000px) {
|
||||
.settings-toggle {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.settings {
|
||||
display: none;
|
||||
background: var(--bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 18px;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.settings.visible {
|
||||
display: flex;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
gap: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
.field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 3px;
|
||||
}
|
||||
|
||||
#chunkSelector,
|
||||
#websocketInput,
|
||||
#themeSelector,
|
||||
#microphoneSelect {
|
||||
font-size: 16px;
|
||||
padding: 5px 8px;
|
||||
border-radius: 8px;
|
||||
border: 1px solid var(--border);
|
||||
background-color: var(--button-bg);
|
||||
color: var(--text);
|
||||
max-height: 30px;
|
||||
}
|
||||
|
||||
#microphoneSelect {
|
||||
width: 100%;
|
||||
max-width: 190px;
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#chunkSelector:focus,
|
||||
#websocketInput:focus,
|
||||
#themeSelector:focus,
|
||||
#microphoneSelect:focus {
|
||||
outline: none;
|
||||
border-color: #007bff;
|
||||
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.15);
|
||||
}
|
||||
|
||||
label {
|
||||
font-size: 13px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.ws-default {
|
||||
font-size: 12px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
/* Segmented pill control for Theme */
|
||||
.segmented {
|
||||
display: inline-flex;
|
||||
align-items: stretch;
|
||||
border: 1px solid var(--button-border);
|
||||
background-color: var(--button-bg);
|
||||
border-radius: 999px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"] {
|
||||
position: absolute;
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-top: 17px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
transition: background-color 0.2s ease, color 0.2s ease;
|
||||
}
|
||||
|
||||
.segmented label span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.segmented label:hover span {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
.segmented label:hover {
|
||||
background-color: var(--chip-bg);
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:checked + label {
|
||||
background-color: var(--chip-bg);
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.segmented input[type="radio"]:focus-visible + label,
|
||||
.segmented input[type="radio"]:focus + label {
|
||||
outline: 2px solid #007bff;
|
||||
outline-offset: 2px;
|
||||
border-radius: 999px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 20px;
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
.transcript-container::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* Transcript area */
|
||||
#linesTranscript {
|
||||
margin: 0 auto;
|
||||
max-width: 700px;
|
||||
text-align: left;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
#linesTranscript p {
|
||||
margin: 0px 0;
|
||||
}
|
||||
|
||||
#linesTranscript strong {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
#speaker {
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
.label_diarization {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
margin-left: 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-dia-text);
|
||||
}
|
||||
|
||||
.label_transcription {
|
||||
background-color: var(--chip-bg);
|
||||
border-radius: 100px;
|
||||
padding: 2px 10px;
|
||||
display: inline-block;
|
||||
white-space: nowrap;
|
||||
margin-left: 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
color: var(--label-trans-text);
|
||||
}
|
||||
|
||||
.label_translation {
|
||||
background-color: var(--chip-bg);
|
||||
display: inline-flex;
|
||||
border-radius: 10px;
|
||||
padding: 4px 8px;
|
||||
margin-top: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--text);
|
||||
align-items: flex-start;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.lag-diarization-value {
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.label_translation img {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
}
|
||||
|
||||
#timeInfo {
|
||||
color: var(--muted);
|
||||
margin-left: 0px;
|
||||
}
|
||||
|
||||
.textcontent {
|
||||
font-size: 16px;
|
||||
padding-left: 10px;
|
||||
margin-bottom: 10px;
|
||||
margin-top: 1px;
|
||||
padding-top: 5px;
|
||||
border-radius: 0px 0px 0px 10px;
|
||||
}
|
||||
|
||||
.buffer_diarization {
|
||||
color: var(--label-dia-text);
|
||||
}
|
||||
|
||||
.buffer_transcription {
|
||||
color: #7474748c;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
.buffer_translation {
|
||||
color: #a0a0a0;
|
||||
margin-left: 6px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
display: inline-block;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border: 2px solid var(--spinner-border);
|
||||
border-top: 2px solid var(--spinner-top);
|
||||
border-radius: 50%;
|
||||
animation: spin 0.7s linear infinite;
|
||||
vertical-align: middle;
|
||||
margin-bottom: 2px;
|
||||
margin-right: 5px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
.silence {
|
||||
color: var(--muted);
|
||||
background-color: var(--silence-bg);
|
||||
font-size: 13px;
|
||||
border-radius: 30px;
|
||||
padding: 2px 10px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
color: var(--muted);
|
||||
background-color: var(--loading-bg);
|
||||
border-radius: 8px 8px 8px 0px;
|
||||
padding: 2px 10px;
|
||||
font-size: 14px;
|
||||
margin-bottom: 0px;
|
||||
}
|
||||
|
||||
/* for smaller screens */
|
||||
@media (max-width: 200px) {
|
||||
.header-container {
|
||||
padding: 15px;
|
||||
}
|
||||
|
||||
.settings-container {
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.buttons-container {
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.field {
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
min-width: 100px;
|
||||
max-width: 160px;
|
||||
}
|
||||
|
||||
.theme-selector-container {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 15px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.header-container {
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.settings {
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
#websocketInput,
|
||||
#microphoneSelect {
|
||||
max-width: 140px;
|
||||
}
|
||||
|
||||
.segmented label {
|
||||
padding: 4px 8px;
|
||||
font-size: 12px;
|
||||
}
|
||||
|
||||
.segmented img {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
.transcript-container {
|
||||
padding: 10px;
|
||||
}
|
||||
}
|
||||
|
||||
.label_language {
|
||||
background-color: var(--chip-bg);
|
||||
margin-bottom: 0px;
|
||||
border-radius: 100px;
|
||||
padding: 2px 8px;
|
||||
margin-left: 10px;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 14px;
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
|
||||
.speaker-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
margin-left: -5px;
|
||||
border-radius: 50%;
|
||||
font-size: 11px;
|
||||
line-height: 1;
|
||||
font-weight: 800;
|
||||
color: var(--muted);
|
||||
}
|
||||
79
whisperlivekit/web/live_transcription.html
Normal file
@@ -0,0 +1,79 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>WhisperLiveKit</title>
|
||||
<link rel="stylesheet" href="live_transcription.css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="header-container">
|
||||
<div class="settings-container">
|
||||
<div class="buttons-container">
|
||||
<button id="recordButton">
|
||||
<div class="shape-container">
|
||||
<div class="shape"></div>
|
||||
</div>
|
||||
<div class="recording-info">
|
||||
<div class="wave-container">
|
||||
<canvas id="waveCanvas"></canvas>
|
||||
</div>
|
||||
<div class="timer">00:00</div>
|
||||
</div>
|
||||
</button>
|
||||
|
||||
<button id="settingsToggle" class="settings-toggle" title="Show/hide settings">
|
||||
<img src="web/src/settings.svg" alt="Settings" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="settings">
|
||||
<div class="field">
|
||||
<label for="websocketInput">Websocket URL</label>
|
||||
<input id="websocketInput" type="text" placeholder="ws://host:port/asr" />
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label id="microphoneSelectLabel" for="microphoneSelect">Select Microphone</label>
|
||||
<select id="microphoneSelect">
|
||||
<option value="">Default Microphone</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="theme-selector-container">
|
||||
<div class="segmented" role="radiogroup" aria-label="Theme selector">
|
||||
<input type="radio" id="theme-system" name="theme" value="system" />
|
||||
<label for="theme-system" title="System">
|
||||
<img src="/web/src/system_mode.svg" alt="" />
|
||||
<span>System</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-light" name="theme" value="light" />
|
||||
<label for="theme-light" title="Light">
|
||||
<img src="/web/src/light_mode.svg" alt="" />
|
||||
<span>Light</span>
|
||||
</label>
|
||||
|
||||
<input type="radio" id="theme-dark" name="theme" value="dark" />
|
||||
<label for="theme-dark" title="Dark">
|
||||
<img src="/web/src/dark_mode.svg" alt="" />
|
||||
<span>Dark</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p id="status"></p>
|
||||
</div>
|
||||
|
||||
<div class="transcript-container">
|
||||
<div id="linesTranscript"></div>
|
||||
</div>
|
||||
|
||||
<script src="live_transcription.js"></script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
817
whisperlivekit/web/live_transcription.js
Normal file
@@ -0,0 +1,817 @@
|
||||
const isExtension = typeof chrome !== 'undefined' && chrome.runtime && chrome.runtime.getURL;
|
||||
if (isExtension) {
|
||||
document.documentElement.classList.add('is-extension');
|
||||
}
|
||||
const isWebContext = !isExtension;
|
||||
|
||||
let isRecording = false;
|
||||
let websocket = null;
|
||||
let recorder = null;
|
||||
let chunkDuration = 100;
|
||||
let websocketUrl = "ws://localhost:8000/asr";
|
||||
let userClosing = false;
|
||||
let wakeLock = null;
|
||||
let startTime = null;
|
||||
let timerInterval = null;
|
||||
let audioContext = null;
|
||||
let analyser = null;
|
||||
let microphone = null;
|
||||
let workletNode = null;
|
||||
let recorderWorker = null;
|
||||
let waveCanvas = document.getElementById("waveCanvas");
|
||||
let waveCtx = waveCanvas.getContext("2d");
|
||||
let animationFrame = null;
|
||||
let waitingForStop = false;
|
||||
let lastReceivedData = null;
|
||||
let lastSignature = null;
|
||||
let availableMicrophones = [];
|
||||
let selectedMicrophoneId = null;
|
||||
let serverUseAudioWorklet = null;
|
||||
let configReadyResolve;
|
||||
const configReady = new Promise((r) => (configReadyResolve = r));
|
||||
let outputAudioContext = null;
|
||||
let audioSource = null;
|
||||
|
||||
waveCanvas.width = 60 * (window.devicePixelRatio || 1);
|
||||
waveCanvas.height = 30 * (window.devicePixelRatio || 1);
|
||||
waveCtx.scale(window.devicePixelRatio || 1, window.devicePixelRatio || 1);
|
||||
|
||||
const statusText = document.getElementById("status");
|
||||
const recordButton = document.getElementById("recordButton");
|
||||
const chunkSelector = document.getElementById("chunkSelector");
|
||||
const websocketInput = document.getElementById("websocketInput");
|
||||
const websocketDefaultSpan = document.getElementById("wsDefaultUrl");
|
||||
const linesTranscriptDiv = document.getElementById("linesTranscript");
|
||||
const timerElement = document.querySelector(".timer");
|
||||
const themeRadios = document.querySelectorAll('input[name="theme"]');
|
||||
const microphoneSelect = document.getElementById("microphoneSelect");
|
||||
|
||||
const settingsToggle = document.getElementById("settingsToggle");
|
||||
const settingsDiv = document.querySelector(".settings");
|
||||
|
||||
// if (isExtension) {
|
||||
// chrome.runtime.onInstalled.addListener((details) => {
|
||||
// if (details.reason.search(/install/g) === -1) {
|
||||
// return;
|
||||
// }
|
||||
// chrome.tabs.create({
|
||||
// url: chrome.runtime.getURL("welcome.html"),
|
||||
// active: true
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
const translationIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12px" viewBox="0 -960 960 960" width="12px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>`
|
||||
const silenceIcon = `<svg xmlns="http://www.w3.org/2000/svg" style="vertical-align: text-bottom;" height="14px" viewBox="0 -960 960 960" width="14px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>`;
|
||||
const languageIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="12" viewBox="0 -960 960 960" width="12" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>`
|
||||
const speakerIcon = `<svg xmlns="http://www.w3.org/2000/svg" height="16px" style="vertical-align: text-bottom;" viewBox="0 -960 960 960" width="16px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>`;
|
||||
|
||||
function getWaveStroke() {
|
||||
const styles = getComputedStyle(document.documentElement);
|
||||
const v = styles.getPropertyValue("--wave-stroke").trim();
|
||||
return v || "#000";
|
||||
}
|
||||
|
||||
let waveStroke = getWaveStroke();
|
||||
function updateWaveStroke() {
|
||||
waveStroke = getWaveStroke();
|
||||
}
|
||||
|
||||
function applyTheme(pref) {
|
||||
if (pref === "light") {
|
||||
document.documentElement.setAttribute("data-theme", "light");
|
||||
} else if (pref === "dark") {
|
||||
document.documentElement.setAttribute("data-theme", "dark");
|
||||
} else {
|
||||
document.documentElement.removeAttribute("data-theme");
|
||||
}
|
||||
updateWaveStroke();
|
||||
}
|
||||
|
||||
// Persisted theme preference
|
||||
const savedThemePref = localStorage.getItem("themePreference") || "system";
|
||||
applyTheme(savedThemePref);
|
||||
if (themeRadios.length) {
|
||||
themeRadios.forEach((r) => {
|
||||
r.checked = r.value === savedThemePref;
|
||||
r.addEventListener("change", () => {
|
||||
if (r.checked) {
|
||||
localStorage.setItem("themePreference", r.value);
|
||||
applyTheme(r.value);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// React to OS theme changes when in "system" mode
|
||||
const darkMq = window.matchMedia && window.matchMedia("(prefers-color-scheme: dark)");
|
||||
const handleOsThemeChange = () => {
|
||||
const pref = localStorage.getItem("themePreference") || "system";
|
||||
if (pref === "system") updateWaveStroke();
|
||||
};
|
||||
if (darkMq && darkMq.addEventListener) {
|
||||
darkMq.addEventListener("change", handleOsThemeChange);
|
||||
} else if (darkMq && darkMq.addListener) {
|
||||
// deprecated, but included for Safari compatibility
|
||||
darkMq.addListener(handleOsThemeChange);
|
||||
}
|
||||
|
||||
async function enumerateMicrophones() {
|
||||
try {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const devices = await navigator.mediaDevices.enumerateDevices();
|
||||
availableMicrophones = devices.filter(device => device.kind === 'audioinput');
|
||||
|
||||
populateMicrophoneSelect();
|
||||
console.log(`Found ${availableMicrophones.length} microphone(s)`);
|
||||
} catch (error) {
|
||||
console.error('Error enumerating microphones:', error);
|
||||
statusText.textContent = "Error accessing microphones. Please grant permission.";
|
||||
}
|
||||
}
|
||||
|
||||
function populateMicrophoneSelect() {
|
||||
if (!microphoneSelect) return;
|
||||
|
||||
microphoneSelect.innerHTML = '<option value="">Default Microphone</option>';
|
||||
|
||||
availableMicrophones.forEach((device, index) => {
|
||||
const option = document.createElement('option');
|
||||
option.value = device.deviceId;
|
||||
option.textContent = device.label || `Microphone ${index + 1}`;
|
||||
microphoneSelect.appendChild(option);
|
||||
});
|
||||
|
||||
const savedMicId = localStorage.getItem('selectedMicrophone');
|
||||
if (savedMicId && availableMicrophones.some(mic => mic.deviceId === savedMicId)) {
|
||||
microphoneSelect.value = savedMicId;
|
||||
selectedMicrophoneId = savedMicId;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMicrophoneChange() {
|
||||
selectedMicrophoneId = microphoneSelect.value || null;
|
||||
localStorage.setItem('selectedMicrophone', selectedMicrophoneId || '');
|
||||
|
||||
const selectedDevice = availableMicrophones.find(mic => mic.deviceId === selectedMicrophoneId);
|
||||
const deviceName = selectedDevice ? selectedDevice.label : 'Default Microphone';
|
||||
|
||||
console.log(`Selected microphone: ${deviceName}`);
|
||||
statusText.textContent = `Microphone changed to: ${deviceName}`;
|
||||
|
||||
if (isRecording) {
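// If a recording is in progress, restart it on the newly selected device after a short delay.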
|
||||
statusText.textContent = "Switching microphone... Please wait.";
|
||||
stopRecording().then(() => {
|
||||
setTimeout(() => {
|
||||
toggleRecording();
|
||||
}, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Helpers
|
||||
function fmt1(x) {
|
||||
const n = Number(x);
|
||||
return Number.isFinite(n) ? n.toFixed(1) : x;
|
||||
}
|
||||
|
||||
let host, port, protocol;
|
||||
port = 8000;
|
||||
if (isExtension) {
|
||||
host = "localhost";
|
||||
protocol = "ws";
|
||||
} else {
|
||||
host = window.location.hostname || "localhost";
|
||||
port = window.location.port;
|
||||
protocol = window.location.protocol === "https:" ? "wss" : "ws";
|
||||
}
|
||||
const defaultWebSocketUrl = `${protocol}://${host}${port ? ":" + port : ""}/asr`;
|
||||
|
||||
// Populate default caption and input
|
||||
if (websocketDefaultSpan) websocketDefaultSpan.textContent = defaultWebSocketUrl;
|
||||
websocketInput.value = defaultWebSocketUrl;
|
||||
websocketUrl = defaultWebSocketUrl;
|
||||
|
||||
// Optional chunk selector (guard for presence)
|
||||
if (chunkSelector) {
|
||||
chunkSelector.addEventListener("change", () => {
|
||||
chunkDuration = parseInt(chunkSelector.value);
|
||||
});
|
||||
}
|
||||
|
||||
// WebSocket input change handling
|
||||
websocketInput.addEventListener("change", () => {
|
||||
const urlValue = websocketInput.value.trim();
|
||||
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
|
||||
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
|
||||
return;
|
||||
}
|
||||
websocketUrl = urlValue;
|
||||
statusText.textContent = "WebSocket URL updated. Ready to connect.";
|
||||
});
|
||||
|
||||
function setupWebSocket() {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
websocket = new WebSocket(websocketUrl);
|
||||
} catch (error) {
|
||||
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
websocket.onopen = () => {
|
||||
statusText.textContent = "Connected to server.";
|
||||
resolve();
|
||||
};
|
||||
|
||||
websocket.onclose = () => {
|
||||
if (userClosing) {
|
||||
if (waitingForStop) {
|
||||
statusText.textContent = "Processing finalized or connection closed.";
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
statusText.textContent = "Disconnected from the WebSocket server. (Check logs if model is loading.)";
|
||||
if (isRecording) {
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
isRecording = false;
|
||||
waitingForStop = false;
|
||||
userClosing = false;
|
||||
lastReceivedData = null;
|
||||
websocket = null;
|
||||
updateUI();
|
||||
};
|
||||
|
||||
websocket.onerror = () => {
|
||||
statusText.textContent = "Error connecting to WebSocket.";
|
||||
reject(new Error("Error connecting to WebSocket"));
|
||||
};
|
||||
|
||||
websocket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === "config") {
|
||||
serverUseAudioWorklet = !!data.useAudioWorklet;
|
||||
statusText.textContent = serverUseAudioWorklet
|
||||
? "Connected. Using AudioWorklet (PCM)."
|
||||
: "Connected. Using MediaRecorder (WebM).";
|
||||
if (configReadyResolve) configReadyResolve();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data.type === "ready_to_stop") {
|
||||
console.log("Ready to stop received, finalizing display and closing WebSocket.");
|
||||
waitingForStop = false;
|
||||
|
||||
if (lastReceivedData) {
|
||||
renderLinesWithBuffer(
|
||||
lastReceivedData.lines || [],
|
||||
lastReceivedData.buffer_diarization || "",
|
||||
lastReceivedData.buffer_transcription || "",
|
||||
lastReceivedData.buffer_translation || "",
|
||||
0,
|
||||
0,
|
||||
true
|
||||
);
|
||||
}
|
||||
statusText.textContent = "Finished processing audio! Ready to record again.";
|
||||
recordButton.disabled = false;
|
||||
|
||||
if (websocket) {
|
||||
websocket.close();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
lastReceivedData = data;
|
||||
|
||||
const {
|
||||
lines = [],
|
||||
buffer_transcription = "",
|
||||
buffer_diarization = "",
|
||||
buffer_translation = "",
|
||||
remaining_time_transcription = 0,
|
||||
remaining_time_diarization = 0,
|
||||
status = "active_transcription",
|
||||
} = data;
|
||||
|
||||
renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
false,
|
||||
status
|
||||
);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function renderLinesWithBuffer(
|
||||
lines,
|
||||
buffer_diarization,
|
||||
buffer_transcription,
|
||||
buffer_translation,
|
||||
remaining_time_diarization,
|
||||
remaining_time_transcription,
|
||||
isFinalizing = false,
|
||||
current_status = "active_transcription"
|
||||
) {
|
||||
if (current_status === "no_audio_detected") {
|
||||
linesTranscriptDiv.innerHTML =
|
||||
"<p style='text-align: center; color: var(--muted); margin-top: 20px;'><em>No audio detected...</em></p>";
|
||||
return;
|
||||
}
|
||||
|
||||
const showLoading = !isFinalizing && (lines || []).some((it) => it.speaker == 0);
|
||||
const showTransLag = !isFinalizing && remaining_time_transcription > 0;
|
||||
const showDiaLag = !isFinalizing && !!buffer_diarization && remaining_time_diarization > 0;
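// Cheap change detection: when the signature matches the previous render, only the lag/loading counters are refreshed below.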
|
||||
const signature = JSON.stringify({
|
||||
lines: (lines || []).map((it) => ({ speaker: it.speaker, text: it.text, start: it.start, end: it.end, detected_language: it.detected_language })),
|
||||
buffer_transcription: buffer_transcription || "",
|
||||
buffer_diarization: buffer_diarization || "",
|
||||
buffer_translation: buffer_translation,
|
||||
status: current_status,
|
||||
showLoading,
|
||||
showTransLag,
|
||||
showDiaLag,
|
||||
isFinalizing: !!isFinalizing,
|
||||
});
|
||||
if (lastSignature === signature) {
|
||||
const t = document.querySelector(".lag-transcription-value");
|
||||
if (t) t.textContent = fmt1(remaining_time_transcription);
|
||||
const d = document.querySelector(".lag-diarization-value");
|
||||
if (d) d.textContent = fmt1(remaining_time_diarization);
|
||||
const ld = document.querySelector(".loading-diarization-value");
|
||||
if (ld) ld.textContent = fmt1(remaining_time_diarization);
|
||||
return;
|
||||
}
|
||||
lastSignature = signature;
|
||||
|
||||
const linesHtml = (lines || [])
|
||||
.map((item, idx) => {
|
||||
let timeInfo = "";
|
||||
if (item.start !== undefined && item.end !== undefined) {
|
||||
timeInfo = ` ${item.start} - ${item.end}`;
|
||||
}
|
||||
|
||||
let speakerLabel = "";
|
||||
if (item.speaker === -2) {
|
||||
speakerLabel = `<span class="silence">${silenceIcon}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
} else if (item.speaker == 0 && !isFinalizing) {
|
||||
speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'><span class="loading-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span> second(s) of audio are undergoing diarization</span></span>`;
|
||||
} else if (item.speaker !== 0) {
|
||||
const speakerNum = `<span class="speaker-badge">${item.speaker}</span>`;
|
||||
speakerLabel = `<span id="speaker">${speakerIcon}${speakerNum}<span id='timeInfo'>${timeInfo}</span></span>`;
|
||||
|
||||
if (item.detected_language) {
|
||||
speakerLabel += `<span class="label_language">${languageIcon}<span>${item.detected_language}</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
let currentLineText = item.text || "";
|
||||
|
||||
if (idx === lines.length - 1) {
|
||||
if (!isFinalizing && item.speaker !== -2) {
|
||||
speakerLabel += `<span class="label_transcription"><span class="spinner"></span>Transcription lag <span id='timeInfo'><span class="lag-transcription-value">${fmt1(
|
||||
remaining_time_transcription
|
||||
)}</span>s</span></span>`;
|
||||
|
||||
if (buffer_diarization && remaining_time_diarization) {
|
||||
speakerLabel += `<span class="label_diarization"><span class="spinner"></span>Diarization lag<span id='timeInfo'><span class="lag-diarization-value">${fmt1(
|
||||
remaining_time_diarization
|
||||
)}</span>s</span></span>`;
|
||||
}
|
||||
}
|
||||
|
||||
if (buffer_diarization) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_diarization.trim().length > 0 ? " " : "") + buffer_diarization.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_diarization">${buffer_diarization}</span>`;
|
||||
}
|
||||
}
|
||||
if (buffer_transcription) {
|
||||
if (isFinalizing) {
|
||||
currentLineText +=
|
||||
(currentLineText.length > 0 && buffer_transcription.trim().length > 0 ? " " : "") +
|
||||
buffer_transcription.trim();
|
||||
} else {
|
||||
currentLineText += `<span class="buffer_transcription">${buffer_transcription}</span>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
let translationContent = "";
|
||||
if (item.translation) {
|
||||
translationContent += item.translation.trim();
|
||||
}
|
||||
if (idx === lines.length - 1 && buffer_translation) {
|
||||
const bufferPiece = isFinalizing
|
||||
? buffer_translation
|
||||
: `<span class="buffer_translation">${buffer_translation}</span>`;
|
||||
translationContent += bufferPiece;
|
||||
}
|
||||
if (translationContent.trim().length > 0) {
|
||||
currentLineText += `
|
||||
<div>
|
||||
<div class="label_translation">
|
||||
${translationIcon}
|
||||
<span class="translation_text">${translationContent}</span>
|
||||
</div>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
return currentLineText.trim().length > 0 || speakerLabel.length > 0
|
||||
? `<p>${speakerLabel}<br/><div class='textcontent'>${currentLineText}</div></p>`
|
||||
: `<p>${speakerLabel}<br/></p>`;
|
||||
})
|
||||
.join("");
|
||||
|
||||
linesTranscriptDiv.innerHTML = linesHtml;
|
||||
const transcriptContainer = document.querySelector('.transcript-container');
|
||||
if (transcriptContainer) {
|
||||
transcriptContainer.scrollTo({ top: transcriptContainer.scrollHeight, behavior: "smooth" });
|
||||
}
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
if (!startTime) return;
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
const minutes = Math.floor(elapsed / 60).toString().padStart(2, "0");
|
||||
const seconds = (elapsed % 60).toString().padStart(2, "0");
|
||||
timerElement.textContent = `${minutes}:${seconds}`;
|
||||
}
|
||||
|
||||
function drawWaveform() {
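// Draw the analyser's time-domain samples as a single waveform line, with coordinates scaled back from devicePixelRatio.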
|
||||
if (!analyser) return;
|
||||
|
||||
const bufferLength = analyser.frequencyBinCount;
|
||||
const dataArray = new Uint8Array(bufferLength);
|
||||
analyser.getByteTimeDomainData(dataArray);
|
||||
|
||||
waveCtx.clearRect(
|
||||
0,
|
||||
0,
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
waveCanvas.height / (window.devicePixelRatio || 1)
|
||||
);
|
||||
waveCtx.lineWidth = 1;
|
||||
waveCtx.strokeStyle = waveStroke;
|
||||
waveCtx.beginPath();
|
||||
|
||||
const sliceWidth = (waveCanvas.width / (window.devicePixelRatio || 1)) / bufferLength;
|
||||
let x = 0;
|
||||
|
||||
for (let i = 0; i < bufferLength; i++) {
|
||||
const v = dataArray[i] / 128.0;
|
||||
const y = (v * (waveCanvas.height / (window.devicePixelRatio || 1))) / 2;
|
||||
|
||||
if (i === 0) {
|
||||
waveCtx.moveTo(x, y);
|
||||
} else {
|
||||
waveCtx.lineTo(x, y);
|
||||
}
|
||||
|
||||
x += sliceWidth;
|
||||
}
|
||||
|
||||
waveCtx.lineTo(
|
||||
waveCanvas.width / (window.devicePixelRatio || 1),
|
||||
(waveCanvas.height / (window.devicePixelRatio || 1)) / 2
|
||||
);
|
||||
waveCtx.stroke();
|
||||
|
||||
animationFrame = requestAnimationFrame(drawWaveform);
|
||||
}
|
||||
|
||||
async function startRecording() {
|
||||
try {
|
||||
try {
|
||||
wakeLock = await navigator.wakeLock.request("screen");
|
||||
} catch (err) {
|
||||
console.log("Error acquiring wake lock.");
|
||||
}
|
||||
|
||||
let stream;
|
||||
|
||||
// Chromium extension: in the future, both Chrome tab audio and the microphone will be used.
|
||||
if (isExtension) {
|
||||
try {
|
||||
stream = await new Promise((resolve, reject) => {
|
||||
chrome.tabCapture.capture({audio: true}, (s) => {
|
||||
if (s) {
|
||||
resolve(s);
|
||||
} else {
|
||||
reject(new Error('Tab capture failed or not available'));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
outputAudioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
audioSource = outputAudioContext.createMediaStreamSource(stream);
|
||||
audioSource.connect(outputAudioContext.destination);
|
||||
} catch (audioError) {
|
||||
console.warn('could not preserve system audio:', audioError);
|
||||
}
|
||||
|
||||
statusText.textContent = "Using tab audio capture.";
|
||||
} catch (tabError) {
|
||||
console.log('Tab capture not available, falling back to microphone', tabError);
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
statusText.textContent = "Using microphone audio.";
|
||||
}
|
||||
} else if (isWebContext) {
|
||||
const audioConstraints = selectedMicrophoneId
|
||||
? { audio: { deviceId: { exact: selectedMicrophoneId } } }
|
||||
: { audio: true };
|
||||
stream = await navigator.mediaDevices.getUserMedia(audioConstraints);
|
||||
}
|
||||
|
||||
audioContext = new (window.AudioContext || window.webkitAudioContext)();
|
||||
analyser = audioContext.createAnalyser();
|
||||
analyser.fftSize = 256;
|
||||
microphone = audioContext.createMediaStreamSource(stream);
|
||||
microphone.connect(analyser);
|
||||
|
||||
if (serverUseAudioWorklet) {
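// PCM path: the worklet forwards raw Float32 frames to a worker, which resamples to 16 kHz 16-bit PCM before sending over the WebSocket.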
|
||||
if (!audioContext.audioWorklet) {
|
||||
throw new Error("AudioWorklet is not supported in this browser");
|
||||
}
|
||||
await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");
|
||||
workletNode = new AudioWorkletNode(audioContext, "pcm-forwarder", { numberOfInputs: 1, numberOfOutputs: 0, channelCount: 1 });
|
||||
microphone.connect(workletNode);
|
||||
|
||||
recorderWorker = new Worker("/web/recorder_worker.js");
|
||||
recorderWorker.postMessage({
|
||||
command: "init",
|
||||
config: {
|
||||
sampleRate: audioContext.sampleRate,
|
||||
},
|
||||
});
|
||||
|
||||
recorderWorker.onmessage = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
websocket.send(e.data.buffer);
|
||||
}
|
||||
};
|
||||
|
||||
workletNode.port.onmessage = (e) => {
|
||||
const data = e.data;
|
||||
const ab = data instanceof ArrayBuffer ? data : data.buffer;
|
||||
recorderWorker.postMessage(
|
||||
{
|
||||
command: "record",
|
||||
buffer: ab,
|
||||
},
|
||||
[ab]
|
||||
);
|
||||
};
|
||||
} else {
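// Fallback path: MediaRecorder emits compressed (WebM when supported) chunks every chunkDuration milliseconds.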
|
||||
try {
|
||||
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
|
||||
} catch (e) {
|
||||
recorder = new MediaRecorder(stream);
|
||||
}
|
||||
recorder.ondataavailable = (e) => {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
if (e.data && e.data.size > 0) {
|
||||
websocket.send(e.data);
|
||||
}
|
||||
}
|
||||
};
|
||||
recorder.start(chunkDuration);
|
||||
}
|
||||
|
||||
startTime = Date.now();
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
drawWaveform();
|
||||
|
||||
isRecording = true;
|
||||
updateUI();
|
||||
} catch (err) {
|
||||
if (window.location.hostname === "0.0.0.0") {
|
||||
statusText.textContent =
|
||||
"Error accessing microphone. Browsers may block microphone access on 0.0.0.0. Try using localhost:8000 instead.";
|
||||
} else {
|
||||
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
|
||||
}
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
async function stopRecording() {
|
||||
if (wakeLock) {
|
||||
try {
|
||||
await wakeLock.release();
|
||||
} catch (e) {
|
||||
// ignore
|
||||
}
|
||||
wakeLock = null;
|
||||
}
|
||||
|
||||
userClosing = true;
|
||||
waitingForStop = true;
|
||||
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
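// Send an empty blob as an end-of-stream marker; the server is expected to reply with a "ready_to_stop" message.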
|
||||
const emptyBlob = new Blob([], { type: "audio/webm" });
|
||||
websocket.send(emptyBlob);
|
||||
statusText.textContent = "Recording stopped. Processing final audio...";
|
||||
}
|
||||
|
||||
if (recorder) {
|
||||
try {
|
||||
recorder.stop();
|
||||
} catch (e) {
|
||||
}
|
||||
recorder = null;
|
||||
}
|
||||
|
||||
if (recorderWorker) {
|
||||
recorderWorker.terminate();
|
||||
recorderWorker = null;
|
||||
}
|
||||
|
||||
if (workletNode) {
|
||||
try {
|
||||
workletNode.port.onmessage = null;
|
||||
} catch (e) {}
|
||||
try {
|
||||
workletNode.disconnect();
|
||||
} catch (e) {}
|
||||
workletNode = null;
|
||||
}
|
||||
|
||||
if (microphone) {
|
||||
microphone.disconnect();
|
||||
microphone = null;
|
||||
}
|
||||
|
||||
if (analyser) {
|
||||
analyser = null;
|
||||
}
|
||||
|
||||
if (audioContext && audioContext.state !== "closed") {
|
||||
try {
|
||||
await audioContext.close();
|
||||
} catch (e) {
|
||||
console.warn("Could not close audio context:", e);
|
||||
}
|
||||
audioContext = null;
|
||||
}
|
||||
|
||||
if (audioSource) {
|
||||
audioSource.disconnect();
|
||||
audioSource = null;
|
||||
}
|
||||
|
||||
if (outputAudioContext && outputAudioContext.state !== "closed") {
|
||||
outputAudioContext.close();
|
||||
outputAudioContext = null;
|
||||
}
|
||||
|
||||
if (animationFrame) {
|
||||
cancelAnimationFrame(animationFrame);
|
||||
animationFrame = null;
|
||||
}
|
||||
|
||||
if (timerInterval) {
|
||||
clearInterval(timerInterval);
|
||||
timerInterval = null;
|
||||
}
|
||||
timerElement.textContent = "00:00";
|
||||
startTime = null;
|
||||
|
||||
isRecording = false;
|
||||
updateUI();
|
||||
}
|
||||
|
||||
async function toggleRecording() {
|
||||
if (!isRecording) {
|
||||
if (waitingForStop) {
|
||||
console.log("Waiting for stop, early return");
|
||||
return;
|
||||
}
|
||||
console.log("Connecting to WebSocket");
|
||||
try {
|
||||
if (websocket && websocket.readyState === WebSocket.OPEN) {
|
||||
await configReady;
|
||||
await startRecording();
|
||||
} else {
|
||||
await setupWebSocket();
|
||||
await configReady;
|
||||
await startRecording();
|
||||
}
|
||||
} catch (err) {
|
||||
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
|
||||
console.error(err);
|
||||
}
|
||||
} else {
|
||||
console.log("Stopping recording");
|
||||
stopRecording();
|
||||
}
|
||||
}
|
||||
|
||||
function updateUI() {
|
||||
recordButton.classList.toggle("recording", isRecording);
|
||||
recordButton.disabled = waitingForStop;
|
||||
|
||||
if (waitingForStop) {
|
||||
if (statusText.textContent !== "Recording stopped. Processing final audio...") {
|
||||
statusText.textContent = "Please wait for processing to complete...";
|
||||
}
|
||||
} else if (isRecording) {
|
||||
statusText.textContent = "";
|
||||
} else {
|
||||
if (
|
||||
statusText.textContent !== "Finished processing audio! Ready to record again." &&
|
||||
statusText.textContent !== "Processing finalized or connection closed."
|
||||
) {
|
||||
statusText.textContent = "Click to start transcription";
|
||||
}
|
||||
}
|
||||
if (!waitingForStop) {
|
||||
recordButton.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
recordButton.addEventListener("click", toggleRecording);
|
||||
|
||||
if (microphoneSelect) {
|
||||
microphoneSelect.addEventListener("change", handleMicrophoneChange);
|
||||
}
|
||||
document.addEventListener('DOMContentLoaded', async () => {
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Could not enumerate microphones on load:", error);
|
||||
}
|
||||
});
|
||||
navigator.mediaDevices.addEventListener('devicechange', async () => {
|
||||
console.log('Device change detected, re-enumerating microphones');
|
||||
try {
|
||||
await enumerateMicrophones();
|
||||
} catch (error) {
|
||||
console.log("Error re-enumerating microphones:", error);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
settingsToggle.addEventListener("click", () => {
|
||||
settingsDiv.classList.toggle("visible");
|
||||
settingsToggle.classList.toggle("active");
|
||||
});
|
||||
|
||||
if (isExtension) {
|
||||
async function checkAndRequestPermissions() {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
|
||||
const permissionDisplay = document.getElementById("audioPermission");
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
|
||||
// if (micPermission.state !== "granted") {
|
||||
// chrome.tabs.create({ url: "welcome.html" });
|
||||
// }
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
const micPermission = await navigator.permissions.query({
|
||||
name: "microphone",
|
||||
});
|
||||
if (micPermission.state === "granted") {
|
||||
if (permissionDisplay) {
|
||||
permissionDisplay.innerText = `MICROPHONE: ${micPermission.state}`;
|
||||
}
|
||||
clearInterval(intervalId);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
void checkAndRequestPermissions();
|
||||
}
|
||||
16
whisperlivekit/web/pcm_worklet.js
Normal file
@@ -0,0 +1,16 @@
|
||||
class PCMForwarder extends AudioWorkletProcessor {
  process(inputs) {
    const input = inputs[0];
    if (input && input[0] && input[0].length) {
      // Forward mono channel (0). If multi-channel, downmixing can be added here.
      const channelData = input[0];
      const copy = new Float32Array(channelData.length);
      copy.set(channelData);
      this.port.postMessage(copy, [copy.buffer]);
    }
    // Keep processor alive
    return true;
  }
}

registerProcessor('pcm-forwarder', PCMForwarder);
|
||||
58
whisperlivekit/web/recorder_worker.js
Normal file
@@ -0,0 +1,58 @@
|
||||
let sampleRate = 48000;
|
||||
let targetSampleRate = 16000;
|
||||
|
||||
self.onmessage = function (e) {
|
||||
switch (e.data.command) {
|
||||
case 'init':
|
||||
init(e.data.config);
|
||||
break;
|
||||
case 'record':
|
||||
record(e.data.buffer);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
function init(config) {
|
||||
sampleRate = config.sampleRate;
|
||||
targetSampleRate = config.targetSampleRate || 16000;
|
||||
}
|
||||
|
||||
function record(inputBuffer) {
|
||||
const buffer = new Float32Array(inputBuffer);
|
||||
const resampledBuffer = resample(buffer, sampleRate, targetSampleRate);
|
||||
const pcmBuffer = toPCM(resampledBuffer);
|
||||
self.postMessage({ buffer: pcmBuffer }, [pcmBuffer]);
|
||||
}
|
||||
|
||||
function resample(buffer, from, to) {
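// Simple averaging resampler: each output sample is the mean of the source samples that map onto it.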
|
||||
if (from === to) {
|
||||
return buffer;
|
||||
}
|
||||
const ratio = from / to;
|
||||
const newLength = Math.round(buffer.length / ratio);
|
||||
const result = new Float32Array(newLength);
|
||||
let offsetResult = 0;
|
||||
let offsetBuffer = 0;
|
||||
while (offsetResult < result.length) {
|
||||
const nextOffsetBuffer = Math.round((offsetResult + 1) * ratio);
|
||||
let accum = 0, count = 0;
|
||||
for (let i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) {
|
||||
accum += buffer[i];
|
||||
count++;
|
||||
}
|
||||
result[offsetResult] = accum / count;
|
||||
offsetResult++;
|
||||
offsetBuffer = nextOffsetBuffer;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
function toPCM(input) {
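// Convert Float32 samples in [-1, 1] to 16-bit little-endian PCM.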
|
||||
const buffer = new ArrayBuffer(input.length * 2);
|
||||
const view = new DataView(buffer);
|
||||
for (let i = 0; i < input.length; i++) {
|
||||
const s = Math.max(-1, Math.min(1, input[i]));
|
||||
view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
1
whisperlivekit/web/src/dark_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-120q-151 0-255.5-104.5T120-480q0-138 90-239.5T440-838q13-2 23 3.5t16 14.5q6 9 6.5 21t-7.5 23q-17 26-25.5 55t-8.5 61q0 90 63 153t153 63q31 0 61.5-9t54.5-25q11-7 22.5-6.5T819-479q10 5 15.5 15t3.5 24q-14 138-117.5 229T480-120Zm0-80q88 0 158-48.5T740-375q-20 5-40 8t-40 3q-123 0-209.5-86.5T364-660q0-20 3-40t8-40q-78 32-126.5 102T200-480q0 116 82 198t198 82Zm-10-270Z"/></svg>
|
||||
1
whisperlivekit/web/src/language.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-80q-82 0-155-31.5t-127.5-86Q143-252 111.5-325T80-480q0-83 31.5-155.5t86-127Q252-817 325-848.5T480-880q83 0 155.5 31.5t127 86q54.5 54.5 86 127T880-480q0 82-31.5 155t-86 127.5q-54.5 54.5-127 86T480-80Zm0-82q26-36 45-75t31-83H404q12 44 31 83t45 75Zm-104-16q-18-33-31.5-68.5T322-320H204q29 50 72.5 87t99.5 55Zm208 0q56-18 99.5-55t72.5-87H638q-9 38-22.5 73.5T584-178ZM170-400h136q-3-20-4.5-39.5T300-480q0-21 1.5-40.5T306-560H170q-5 20-7.5 39.5T160-480q0 21 2.5 40.5T170-400Zm216 0h188q3-20 4.5-39.5T580-480q0-21-1.5-40.5T574-560H386q-3 20-4.5 39.5T380-480q0 21 1.5 40.5T386-400Zm268 0h136q5-20 7.5-39.5T800-480q0-21-2.5-40.5T790-560H654q3 20 4.5 39.5T660-480q0 21-1.5 40.5T654-400Zm-16-240h118q-29-50-72.5-87T584-782q18 33 31.5 68.5T638-640Zm-234 0h152q-12-44-31-83t-45-75q-26 36-45 75t-31 83Zm-200 0h118q9-38 22.5-73.5T376-782q-56 18-99.5 55T204-640Z"/></svg>
|
||||
1
whisperlivekit/web/src/light_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-360q50 0 85-35t35-85q0-50-35-85t-85-35q-50 0-85 35t-35 85q0 50 35 85t85 35Zm0 80q-83 0-141.5-58.5T280-480q0-83 58.5-141.5T480-680q83 0 141.5 58.5T680-480q0 83-58.5 141.5T480-280ZM80-440q-17 0-28.5-11.5T40-480q0-17 11.5-28.5T80-520h80q17 0 28.5 11.5T200-480q0 17-11.5 28.5T160-440H80Zm720 0q-17 0-28.5-11.5T760-480q0-17 11.5-28.5T800-520h80q17 0 28.5 11.5T920-480q0 17-11.5 28.5T880-440h-80ZM480-760q-17 0-28.5-11.5T440-800v-80q0-17 11.5-28.5T480-920q17 0 28.5 11.5T520-880v80q0 17-11.5 28.5T480-760Zm0 720q-17 0-28.5-11.5T440-80v-80q0-17 11.5-28.5T480-200q17 0 28.5 11.5T520-160v80q0 17-11.5 28.5T480-40ZM226-678l-43-42q-12-11-11.5-28t11.5-29q12-12 29-12t28 12l42 43q11 12 11 28t-11 28q-11 12-27.5 11.5T226-678Zm494 495-42-43q-11-12-11-28.5t11-27.5q11-12 27.5-11.5T734-282l43 42q12 11 11.5 28T777-183q-12 12-29 12t-28-12Zm-42-495q-12-11-11.5-27.5T678-734l42-43q11-12 28-11.5t29 11.5q12 12 12 29t-12 28l-43 42q-12 11-28 11t-28-11ZM183-183q-12-12-12-29t12-28l43-42q12-11 28.5-11t27.5 11q12 11 11.5 27.5T282-226l-42 43q-11 12-28 11.5T183-183Zm297-297Z"/></svg>
|
||||
1
whisperlivekit/web/src/settings.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M433-80q-27 0-46.5-18T363-142l-9-66q-13-5-24.5-12T307-235l-62 26q-25 11-50 2t-39-32l-47-82q-14-23-8-49t27-43l53-40q-1-7-1-13.5v-27q0-6.5 1-13.5l-53-40q-21-17-27-43t8-49l47-82q14-23 39-32t50 2l62 26q11-8 23-15t24-12l9-66q4-26 23.5-44t46.5-18h94q27 0 46.5 18t23.5 44l9 66q13 5 24.5 12t22.5 15l62-26q25-11 50-2t39 32l47 82q14 23 8 49t-27 43l-53 40q1 7 1 13.5v27q0 6.5-2 13.5l53 40q21 17 27 43t-8 49l-48 82q-14 23-39 32t-50-2l-60-26q-11 8-23 15t-24 12l-9 66q-4 26-23.5 44T527-80h-94Zm7-80h79l14-106q31-8 57.5-23.5T639-327l99 41 39-68-86-65q5-14 7-29.5t2-31.5q0-16-2-31.5t-7-29.5l86-65-39-68-99 42q-22-23-48.5-38.5T533-694l-13-106h-79l-14 106q-31 8-57.5 23.5T321-633l-99-41-39 68 86 64q-5 15-7 30t-2 32q0 16 2 31t7 30l-86 65 39 68 99-42q22 23 48.5 38.5T427-266l13 106Zm42-180q58 0 99-41t41-99q0-58-41-99t-99-41q-59 0-99.5 41T342-480q0 58 40.5 99t99.5 41Zm-2-140Z"/></svg>
|
||||
1
whisperlivekit/web/src/silence.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M514-556 320-752q9-3 19-5.5t21-2.5q66 0 113 47t47 113q0 11-1.5 22t-4.5 22ZM40-200v-32q0-33 17-62t47-44q51-26 115-44t141-18q26 0 49.5 2.5T456-392l-56-54q-9 3-19 4.5t-21 1.5q-66 0-113-47t-47-113q0-11 1.5-21t4.5-19L84-764q-11-11-11-28t11-28q12-12 28.5-12t27.5 12l675 685q11 11 11.5 27.5T816-80q-11 13-28 12.5T759-80L641-200h39q0 33-23.5 56.5T600-120H120q-33 0-56.5-23.5T40-200Zm80 0h480v-32q0-14-4.5-19.5T580-266q-36-18-92.5-36T360-320q-71 0-127.5 18T140-266q-9 5-14.5 14t-5.5 20v32Zm240 0Zm560-400q0 69-24.5 131.5T829-355q-12 14-30 15t-32-13q-13-13-12-31t12-33q30-38 46.5-85t16.5-98q0-51-16.5-97T767-781q-12-15-12.5-33t12.5-32q13-14 31.5-13.5T829-845q42 51 66.5 113.5T920-600Zm-182 0q0 32-10 61.5T700-484q-11 15-29.5 15.5T638-482q-13-13-13.5-31.5T633-549q6-11 9.5-24t3.5-27q0-14-3.5-27t-9.5-25q-9-17-8.5-35t13.5-31q14-14 32.5-13.5T700-716q18 25 28 54.5t10 61.5Z"/></svg>
|
||||
1
whisperlivekit/web/src/speaker.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M480-480q-66 0-113-47t-47-113q0-66 47-113t113-47q66 0 113 47t47 113q0 66-47 113t-113 47ZM160-240v-32q0-34 17.5-62.5T224-378q62-31 126-46.5T480-440q66 0 130 15.5T736-378q29 15 46.5 43.5T800-272v32q0 33-23.5 56.5T720-160H240q-33 0-56.5-23.5T160-240Zm80 0h480v-32q0-11-5.5-20T700-306q-54-27-109-40.5T480-360q-56 0-111 13.5T260-306q-9 5-14.5 14t-5.5 20v32Zm240-320q33 0 56.5-23.5T560-640q0-33-23.5-56.5T480-720q-33 0-56.5 23.5T400-640q0 33 23.5 56.5T480-560Zm0-80Zm0 400Z"/></svg>
|
||||
1
whisperlivekit/web/src/system_mode.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="M396-396q-32-32-58.5-67T289-537q-5 14-6.5 28.5T281-480q0 83 58 141t141 58q14 0 28.5-2t28.5-6q-39-22-74-48.5T396-396Zm85 196q-56 0-107-21t-91-61q-40-40-61-91t-21-107q0-51 17-97.5t50-84.5q13-14 32-9.5t27 24.5q21 55 52.5 104t73.5 91q42 42 91 73.5T648-326q20 8 24.5 27t-9.5 32q-38 33-84.5 50T481-200Zm223-192q-16-5-23-20.5t-4-32.5q9-48-6-94.5T621-621q-35-35-80.5-49.5T448-677q-17 3-32-4t-21-23q-6-16 1.5-31t23.5-19q69-15 138 4.5T679-678q51 51 71 120t5 138q-4 17-19 25t-32 3ZM480-840q-17 0-28.5-11.5T440-880v-40q0-17 11.5-28.5T480-960q17 0 28.5 11.5T520-920v40q0 17-11.5 28.5T480-840Zm0 840q-17 0-28.5-11.5T440-40v-40q0-17 11.5-28.5T480-120q17 0 28.5 11.5T520-80v40q0 17-11.5 28.5T480 0Zm255-734q-12-12-12-28.5t12-28.5l28-28q11-11 27.5-11t28.5 11q12 12 12 28.5T819-762l-28 28q-12 12-28 12t-28-12ZM141-141q-12-12-12-28.5t12-28.5l28-28q12-12 28-12t28 12q12 12 12 28.5T225-169l-28 28q-11 11-27.5 11T141-141Zm739-299q-17 0-28.5-11.5T840-480q0-17 11.5-28.5T880-520h40q17 0 28.5 11.5T960-480q0 17-11.5 28.5T920-440h-40Zm-840 0q-17 0-28.5-11.5T0-480q0-17 11.5-28.5T40-520h40q17 0 28.5 11.5T120-480q0 17-11.5 28.5T80-440H40Zm779 299q-12 12-28.5 12T762-141l-28-28q-12-12-12-28t12-28q12-12 28.5-12t28.5 12l28 28q11 11 11 27.5T819-141ZM226-735q-12 12-28.5 12T169-735l-28-28q-11-11-11-27.5t11-28.5q12-12 28.5-12t28.5 12l28 28q12 12 12 28t-12 28Zm170 339Z"/></svg>
|
||||
1
whisperlivekit/web/src/translate.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#5f6368"><path d="m603-202-34 97q-4 11-14 18t-22 7q-20 0-32.5-16.5T496-133l152-402q5-11 15-18t22-7h30q12 0 22 7t15 18l152 403q8 19-4 35.5T868-80q-13 0-22.5-7T831-106l-34-96H603ZM362-401 188-228q-11 11-27.5 11.5T132-228q-11-11-11-28t11-28l174-174q-35-35-63.5-80T190-640h84q20 39 40 68t48 58q33-33 68.5-92.5T484-720H80q-17 0-28.5-11.5T40-760q0-17 11.5-28.5T80-800h240v-40q0-17 11.5-28.5T360-880q17 0 28.5 11.5T400-840v40h240q17 0 28.5 11.5T680-760q0 17-11.5 28.5T640-720h-76q-21 72-63 148t-83 116l96 98-30 82-122-125Zm266 129h144l-72-204-72 204Z"/></svg>
|
||||
114
whisperlivekit/web/web_interface.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import importlib.resources as resources
|
||||
import base64
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_web_interface_html():
|
||||
"""Loads the HTML for the web interface using importlib.resources."""
|
||||
try:
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading web interface HTML: {e}")
|
||||
return "<html><body><h1>Error loading interface</h1></body></html>"
|
||||
|
||||
def get_inline_ui_html():
|
||||
"""Returns the complete web interface HTML with all assets embedded in a single call."""
|
||||
try:
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.html').open('r', encoding='utf-8') as f:
|
||||
html_content = f.read()
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.css').open('r', encoding='utf-8') as f:
|
||||
css_content = f.read()
|
||||
with resources.files('whisperlivekit.web').joinpath('live_transcription.js').open('r', encoding='utf-8') as f:
|
||||
js_content = f.read()
|
||||
|
||||
with resources.files('whisperlivekit.web').joinpath('pcm_worklet.js').open('r', encoding='utf-8') as f:
|
||||
worklet_code = f.read()
|
||||
with resources.files('whisperlivekit.web').joinpath('recorder_worker.js').open('r', encoding='utf-8') as f:
|
||||
worker_code = f.read()
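# Inline the worklet and worker sources as Blob URLs so the embedded single-file page needs no extra static file routes.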
|
||||
|
||||
js_content = js_content.replace(
|
||||
'await audioContext.audioWorklet.addModule("/web/pcm_worklet.js");',
|
||||
'const workletBlob = new Blob([`' + worklet_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workletUrl = URL.createObjectURL(workletBlob);\n' +
|
||||
'await audioContext.audioWorklet.addModule(workletUrl);'
|
||||
)
|
||||
js_content = js_content.replace(
|
||||
'recorderWorker = new Worker("/web/recorder_worker.js");',
|
||||
'const workerBlob = new Blob([`' + worker_code + '`], { type: "application/javascript" });\n' +
|
||||
'const workerUrl = URL.createObjectURL(workerBlob);\n' +
|
||||
'recorderWorker = new Worker(workerUrl);'
|
||||
)
|
||||
|
||||
# SVG files
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'system_mode.svg').open('r', encoding='utf-8') as f:
|
||||
system_svg = f.read()
|
||||
system_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(system_svg.encode('utf-8')).decode('utf-8')}"
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'light_mode.svg').open('r', encoding='utf-8') as f:
|
||||
light_svg = f.read()
|
||||
light_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(light_svg.encode('utf-8')).decode('utf-8')}"
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'dark_mode.svg').open('r', encoding='utf-8') as f:
|
||||
dark_svg = f.read()
|
||||
dark_data_uri = f"data:image/svg+xml;base64,{base64.b64encode(dark_svg.encode('utf-8')).decode('utf-8')}"
|
||||
with resources.files('whisperlivekit.web').joinpath('src', 'settings.svg').open('r', encoding='utf-8') as f:
|
||||
settings = f.read()
|
||||
settings_uri = f"data:image/svg+xml;base64,{base64.b64encode(settings.encode('utf-8')).decode('utf-8')}"
|
||||
|
||||
# Replace external references
|
||||
html_content = html_content.replace(
|
||||
'<link rel="stylesheet" href="live_transcription.css" />',
|
||||
f'<style>\n{css_content}\n</style>'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<script src="live_transcription.js"></script>',
|
||||
f'<script>\n{js_content}\n</script>'
|
||||
)
|
||||
|
||||
# Replace SVG references
|
||||
html_content = html_content.replace(
|
||||
'<img src="/web/src/system_mode.svg" alt="" />',
|
||||
f'<img src="{system_data_uri}" alt="" />'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<img src="/web/src/light_mode.svg" alt="" />',
|
||||
f'<img src="{light_data_uri}" alt="" />'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<img src="/web/src/dark_mode.svg" alt="" />',
|
||||
f'<img src="{dark_data_uri}" alt="" />'
|
||||
)
|
||||
|
||||
html_content = html_content.replace(
|
||||
'<img src="web/src/settings.svg" alt="Settings" />',
|
||||
f'<img src="{settings_uri}" alt="" />'
|
||||
)
|
||||
|
||||
return html_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating embedded web interface: {e}")
|
||||
return "<html><body><h1>Error loading embedded interface</h1></body></html>"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
import uvicorn
|
||||
from starlette.staticfiles import StaticFiles
|
||||
import pathlib
|
||||
import whisperlivekit.web as webpkg
|
||||
|
||||
app = FastAPI()
|
||||
web_dir = pathlib.Path(webpkg.__file__).parent
|
||||
app.mount("/web", StaticFiles(directory=str(web_dir)), name="web")
|
||||
|
||||
@app.get("/")
|
||||
async def get():
|
||||
return HTMLResponse(get_inline_ui_html())
|
||||
|
||||
uvicorn.run(app=app)
|
||||
463
whisperlivekit/whisper/__init__.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
|
||||
from whisperlivekit.whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from whisperlivekit.whisper.decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from whisperlivekit.whisper.model import ModelDimensions, Whisper
|
||||
from whisperlivekit.whisper.transcribe import transcribe
|
||||
from whisperlivekit.whisper.version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
"turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
"turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
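# A minimal sketch (not part of the library code) of the convention _download relies on:
# the SHA-256 digest is embedded as the second-to-last path component of every
# _MODELS URL, so the expected checksum can be recovered without any download.
#
#     url = _MODELS["base"]
#     expected_sha256 = url.split("/")[-2]
#     # -> "ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e"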
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def _infer_dims_from_config(path: str) -> Optional[ModelDimensions]:
|
||||
"""
|
||||
    Attempt to infer ModelDimensions from a Hugging Face-style config.json located
|
||||
    next to the given checkpoint; useful for distilled models.
|
||||
"""
|
||||
candidates = []
|
||||
if os.path.isdir(path):
|
||||
candidates.append(os.path.join(path, "config.json"))
|
||||
else:
|
||||
candidates.append(os.path.join(os.path.dirname(path), "config.json"))
|
||||
|
||||
for candidate in candidates:
|
||||
if not os.path.isfile(candidate):
|
||||
continue
|
||||
with open(candidate, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
try:
|
||||
return ModelDimensions(
|
||||
n_mels=config["num_mel_bins"],
|
||||
n_audio_ctx=config["max_source_positions"],
|
||||
n_audio_state=config["d_model"],
|
||||
n_audio_head=config["encoder_attention_heads"],
|
||||
n_audio_layer=config.get("encoder_layers")
|
||||
or config["num_hidden_layers"],
|
||||
n_vocab=config["vocab_size"],
|
||||
n_text_ctx=config["max_target_positions"],
|
||||
n_text_state=config["d_model"],
|
||||
n_text_head=config["decoder_attention_heads"],
|
||||
n_text_layer=config["decoder_layers"],
|
||||
)
|
||||
except KeyError as err:
|
||||
warnings.warn(f"Missing key {err} in HuggingFace config {candidate}")
|
||||
return None
|
||||
|
||||
return None
|
||||
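# Hedged example of the Hugging Face config.json fields this helper reads; the
# values are illustrative placeholders, not taken from any particular checkpoint:
#
#     {
#       "num_mel_bins": 80, "max_source_positions": 1500, "d_model": 768,
#       "encoder_attention_heads": 12, "encoder_layers": 2,
#       "vocab_size": 51865, "max_target_positions": 448,
#       "decoder_attention_heads": 12, "decoder_layers": 2
#     }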
|
||||
|
||||
def _convert_hf_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
    Convert a Hugging Face checkpoint state_dict into the parameter naming convention
|
||||
    used by the original Whisper implementation.
|
||||
"""
|
||||
|
||||
if not any(k.startswith("model.") for k in state_dict):
|
||||
return state_dict
|
||||
|
||||
def map_block(prefix: str, target_prefix: str, remainder: str) -> Optional[str]:
|
||||
if remainder.startswith("self_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "attn.query",
|
||||
"k_proj": "attn.key",
|
||||
"v_proj": "attn.value",
|
||||
"out_proj": "attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".")[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "self_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.attn_ln.weight"
|
||||
elif remainder == "self_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.attn_ln.bias"
|
||||
elif remainder.startswith("encoder_attn."):
|
||||
suffix = remainder.split(".", 1)[1]
|
||||
mapping = {
|
||||
"q_proj": "cross_attn.query",
|
||||
"k_proj": "cross_attn.key",
|
||||
"v_proj": "cross_attn.value",
|
||||
"out_proj": "cross_attn.out",
|
||||
}
|
||||
stem = mapping.get(suffix.split(".", 1)[0])
|
||||
if stem:
|
||||
rest = suffix.split(".", 1)[1] if "." in suffix else ""
|
||||
return f"{target_prefix}.{stem}" + (f".{rest}" if rest else "")
|
||||
elif remainder == "encoder_attn_layer_norm.weight":
|
||||
return f"{target_prefix}.cross_attn_ln.weight"
|
||||
elif remainder == "encoder_attn_layer_norm.bias":
|
||||
return f"{target_prefix}.cross_attn_ln.bias"
|
||||
elif remainder.startswith("fc1."):
|
||||
return f"{target_prefix}.mlp.0.{remainder.split('.',1)[1]}"
|
||||
elif remainder.startswith("fc2."):
|
||||
return f"{target_prefix}.mlp.2.{remainder.split('.',1)[1]}"
|
||||
elif remainder == "final_layer_norm.weight":
|
||||
return f"{target_prefix}.mlp_ln.weight"
|
||||
elif remainder == "final_layer_norm.bias":
|
||||
return f"{target_prefix}.mlp_ln.bias"
|
||||
return None
|
||||
|
||||
converted = {}
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("model."):
|
||||
continue
|
||||
subkey = key[len("model.") :]
|
||||
|
||||
if subkey.startswith("encoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"encoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("decoder.layers."):
|
||||
parts = subkey.split(".")
|
||||
layer_idx = parts[2]
|
||||
remainder = ".".join(parts[3:])
|
||||
mapped = map_block(subkey, f"decoder.blocks.{layer_idx}", remainder)
|
||||
elif subkey.startswith("encoder.conv") or subkey.startswith("decoder.conv"):
|
||||
mapped = subkey
|
||||
elif subkey == "encoder.embed_positions.weight":
|
||||
mapped = "encoder.positional_embedding"
|
||||
elif subkey == "decoder.embed_positions.weight":
|
||||
mapped = "decoder.positional_embedding"
|
||||
elif subkey == "encoder.layer_norm.weight":
|
||||
mapped = "encoder.ln_post.weight"
|
||||
elif subkey == "encoder.layer_norm.bias":
|
||||
mapped = "encoder.ln_post.bias"
|
||||
elif subkey.startswith("decoder.embed_tokens."):
|
||||
mapped = subkey.replace("embed_tokens", "token_embedding", 1)
|
||||
elif subkey == "decoder.layer_norm.weight":
|
||||
mapped = "decoder.ln.weight"
|
||||
elif subkey == "decoder.layer_norm.bias":
|
||||
mapped = "decoder.ln.bias"
|
||||
else:
|
||||
mapped = None
|
||||
|
||||
if mapped:
|
||||
converted[mapped] = value
|
||||
|
||||
return converted if converted else state_dict
|
||||
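# Illustration of the renaming performed above; the layer indices are arbitrary,
# but each arrow follows the mapping implemented in this function:
#
#     model.encoder.layers.0.self_attn.q_proj.weight    -> encoder.blocks.0.attn.query.weight
#     model.decoder.layers.3.encoder_attn.out_proj.bias -> decoder.blocks.3.cross_attn.out.bias
#     model.decoder.embed_tokens.weight                  -> decoder.token_embedding.weight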
|
||||
|
||||
def _load_lora_state(lora_path: str):
|
||||
safe_path = os.path.join(lora_path, "adapter_model.safetensors")
|
||||
bin_path = os.path.join(lora_path, "adapter_model.bin")
|
||||
if os.path.isfile(safe_path):
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Loading LoRA adapters stored as .safetensors requires the `safetensors` package."
|
||||
) from exc
|
||||
return load_file(safe_path)
|
||||
if os.path.isfile(bin_path):
|
||||
return torch.load(bin_path, map_location="cpu")
|
||||
raise FileNotFoundError(
|
||||
f"No adapter weights found under {lora_path}. Expected adapter_model.safetensors or adapter_model.bin."
|
||||
)
|
||||
|
||||
|
||||
def _collapse_hf_module_name(module: str):
|
||||
if module.startswith("base_model."):
|
||||
module = module[len("base_model.") :]
|
||||
if module.startswith("model.model."):
|
||||
module = module[len("model.") :]
|
||||
if not module.startswith("model."):
|
||||
module = f"model.{module}"
|
||||
return module
|
||||
|
||||
|
||||
def _apply_lora_adapter(state_dict: Dict[str, Tensor], lora_path: Optional[str]):
|
||||
if not lora_path:
|
||||
return
|
||||
|
||||
config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
if not os.path.isfile(config_path):
|
||||
raise FileNotFoundError(f"Missing adapter_config.json inside {lora_path}")
|
||||
with open(config_path, "r", encoding="utf-8") as handle:
|
||||
config = json.load(handle)
|
||||
if config.get("peft_type") != "LORA":
|
||||
raise ValueError("Only LoRA adapters are supported.")
|
||||
|
||||
r = config.get("r")
|
||||
alpha = config.get("lora_alpha") or config.get("alpha")
|
||||
if not r or not alpha:
|
||||
raise ValueError("LoRA config must include `r` and `lora_alpha`.")
|
||||
scaling = alpha / r
|
||||
|
||||
adapter_state = _load_lora_state(lora_path)
|
||||
lora_layers: Dict[str, Dict[str, Tensor]] = {}
|
||||
for key, tensor in adapter_state.items():
|
||||
if key.endswith("lora_A.weight"):
|
||||
module = key[: -len(".lora_A.weight")]
|
||||
lora_layers.setdefault(module, {})["A"] = tensor
|
||||
elif key.endswith("lora_B.weight"):
|
||||
module = key[: -len(".lora_B.weight")]
|
||||
lora_layers.setdefault(module, {})["B"] = tensor
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError(f"No LoRA tensors found in {lora_path}")
|
||||
|
||||
for module, parts in lora_layers.items():
|
||||
if "A" not in parts or "B" not in parts:
|
||||
raise ValueError(f"Incomplete LoRA tensors for module '{module}'")
|
||||
|
||||
hf_module = _collapse_hf_module_name(module)
|
||||
hf_weight_key = f"{hf_module}.weight"
|
||||
|
||||
delta = parts["B"] @ parts["A"]
|
||||
delta = delta * scaling
|
||||
|
||||
converted = _convert_hf_state_dict({hf_weight_key: delta})
|
||||
if not converted:
|
||||
raise KeyError(f"Failed to map LoRA module '{module}' into Whisper state dict.")
|
||||
target_name, delta_tensor = next(iter(converted.items()))
|
||||
if target_name not in state_dict:
|
||||
raise KeyError(
|
||||
f"LoRA module '{module}' mapped to '{target_name}', but the base model has no such parameter."
|
||||
)
|
||||
|
||||
state_dict[target_name] = state_dict[target_name] + delta_tensor.to(
|
||||
dtype=state_dict[target_name].dtype, device=state_dict[target_name].device
|
||||
)
|
||||
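# Sketch of the merge arithmetic above with made-up shapes (rank r=8,
# lora_alpha=16 on a 768x768 projection); shown only for illustration:
#
#     A = lora_A.weight              # (8, 768)
#     B = lora_B.weight              # (768, 8)
#     delta = (B @ A) * (16 / 8)     # (768, 768), scaled by lora_alpha / r
#     W_merged = W_base + delta      # added onto the matching base parameter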
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
    download_root: Optional[str] = None,
|
||||
in_memory: bool = False,
|
||||
decoder_only: bool = False,
|
||||
custom_alignment_heads: Optional[str] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
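    decoder_only: bool
        if True, drop encoder weights from the checkpoint and build a decoder-only model
    custom_alignment_heads: str
        optional base85-encoded alignment-head mask overriding the built-in table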
lora_path: str
|
||||
optional directory containing PEFT LoRA adapter weights (adapter_config + adapter_model)
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
alignment_heads = _ALIGNMENT_HEADS.get(name, None)
|
||||
if custom_alignment_heads:
|
||||
alignment_heads = custom_alignment_heads.encode()
|
||||
|
||||
    if isinstance(checkpoint_file, str) and checkpoint_file.endswith(".safetensors"):
        try:
            from safetensors.torch import load_file
        except ImportError:
            raise ImportError(
                "Please install safetensors to load .safetensors model files: `pip install safetensors`"
            )
        # safetensors maps tensors straight onto the requested device, so the
        # in-memory and on-disk cases reduce to the same call
        checkpoint = load_file(checkpoint_file, device=str(device))
|
||||
else:
|
||||
with (
|
||||
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims_cfg = checkpoint.get("dims") if isinstance(checkpoint, dict) else None
|
||||
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
state_dict = _convert_hf_state_dict(state_dict)
|
||||
_apply_lora_adapter(state_dict, lora_path)
|
||||
|
||||
if dims_cfg is not None:
|
||||
dims = ModelDimensions(**dims_cfg)
|
||||
else:
|
||||
dims = _infer_dims_from_config(name)
|
||||
if dims is None:
|
||||
raise RuntimeError(
|
||||
"Could not determine model dimensions. "
|
||||
"Ensure the checkpoint includes 'dims' or a HuggingFace config.json is present."
|
||||
)
|
||||
if not isinstance(state_dict, dict):
|
||||
state_dict = checkpoint
|
||||
|
||||
model = Whisper(dims, decoder_only=decoder_only)
|
||||
|
||||
if decoder_only:
|
||||
state_dict = {
|
||||
k: v for k, v in state_dict.items()
|
||||
if 'encoder' not in k
|
||||
}
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
||||
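# Hypothetical usage (paths are placeholders, nothing here is executed on import):
#
#     model = load_model("small", device="cuda", lora_path="/path/to/adapter")
#     # the adapter directory must contain adapter_config.json plus
#     # adapter_model.safetensors or adapter_model.bin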
|
||||
|
||||
def convert_encoder_to_coreml(
    model_name="base",
    output_path="whisper_encoder.mlpackage",
    dummy_frames=3000,  # number of time frames for the dummy mel input during tracing
    precision="float16",
):
|
||||
|
||||
import coremltools as ct
|
||||
model = load_model(model_name, device="cpu", decoder_only=False)
|
||||
encoder = model.encoder.eval().cpu()
|
||||
|
||||
dummy_input = torch.randn(
|
||||
1,
|
||||
model.dims.n_mels,
|
||||
dummy_frames,
|
||||
dtype=next(encoder.parameters()).dtype,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
traced_encoder = torch.jit.trace(encoder, dummy_input)
|
||||
|
||||
precision_map = {
|
||||
"float16": ct.precision.FLOAT16,
|
||||
"fp16": ct.precision.FLOAT16,
|
||||
"float32": ct.precision.FLOAT32,
|
||||
"fp32": ct.precision.FLOAT32,
|
||||
}
|
||||
coreml_precision = precision_map[precision.lower()]
|
||||
|
||||
mlmodel = ct.convert(
|
||||
traced_encoder,
|
||||
inputs=[ct.TensorType(name="mel", shape=dummy_input.shape)],
|
||||
convert_to= "mlprogram",
|
||||
compute_precision=coreml_precision,
|
||||
)
|
||||
|
||||
output_path = Path(output_path)
|
||||
mlmodel.save(str(output_path))
|
||||
return output_path
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# convert_encoder_to_coreml(model_name="tiny", output_path="whisper_encoder.mlpackage", dummy_frames=3000, precision="float16", convert_to="mlprogram")
|
||||
3
whisperlivekit/whisper/__main__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .transcribe import cli
|
||||
|
||||
cli()
|
||||
0
whisperlivekit/whisper/assets/__init__.py
Normal file
50256
whisperlivekit/whisper/assets/gpt2.tiktoken
Normal file
BIN
whisperlivekit/whisper/assets/mel_filters.npz
Normal file
50257
whisperlivekit/whisper/assets/multilingual.tiktoken
Normal file
157
whisperlivekit/whisper/audio.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from subprocess import CalledProcessError, run
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the second encoder convolution has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
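# Worked out from the constants above, for reference:
#     N_SAMPLES         = 30 * 16000   = 480000 samples per chunk
#     N_FRAMES          = 480000 / 160 = 3000 mel frames per chunk
#     FRAMES_PER_SECOND = 16000 / 160  = 100 frames/s  (10 ms per frame)
#     TOKENS_PER_SECOND = 16000 / 320  = 50 tokens/s   (20 ms per token)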
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
|
||||
sr: int
|
||||
The sample rate to resample the audio if necessary
|
||||
|
||||
Returns
|
||||
-------
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# fmt: off
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
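# Minimal sketch of the intended call pattern ("example.wav" is a placeholder,
# and ffmpeg must be available on PATH):
#
#     audio = load_audio("example.wav")    # float32 mono at 16 kHz
#     audio = pad_or_trim(audio)           # exactly N_SAMPLES samples
#     mel = log_mel_spectrogram(audio)     # (80, 3000) tensor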
|
||||
|
||||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(
|
||||
dim=axis, index=torch.arange(length, device=array.device)
|
||||
)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
||||
else:
|
||||
if array.shape[axis] > length:
|
||||
array = array.take(indices=range(length), axis=axis)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = np.pad(array, pad_widths)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
||||
)
|
||||
"""
|
||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
with np.load(filters_path, allow_pickle=False) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = 80,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
    Compute the log-Mel spectrogram of an audio waveform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||
        The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 and 128 are supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (n_mels, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
826
whisperlivekit/whisper/decoding.py
Normal file
@@ -0,0 +1,826 @@
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
||||
) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
    Detect the spoken language in the audio, and return the detected languages as a list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : Tensor, shape = (n_audio,)
|
||||
        ids of the most probable language tokens, which appear after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
):
|
||||
raise ValueError(
|
||||
"This model doesn't have language tokens so it can't perform lang id"
|
||||
)
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
||||
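# Hedged usage sketch (assumes a multilingual model and a single 2-D mel tensor
# on the same device as the model; "model" and "mel" are placeholders):
#
#     tokens, probs = detect_language(model, mel)
#     print(max(probs, key=probs.get))   # most probable language code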
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
# whether to perform X->X "transcribe" or X->English "translate"
|
||||
task: str = "transcribe"
|
||||
|
||||
# language that the audio is in; uses detected language if None
|
||||
language: Optional[str] = None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
||||
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
||||
|
||||
# "alpha" in Google NMT, or None for length norm, when ranking generations
|
||||
# to select which to return among the beams or best-of-N samples
|
||||
length_penalty: Optional[float] = None
|
||||
|
||||
# text or tokens to feed as the prompt or the prefix; for more info:
|
||||
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
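# Illustrative only: a beam-search configuration built from the fields above
# (any field left out keeps the default declared in the dataclass):
#
#     options = DecodingOptions(task="transcribe", language="en", beam_size=5,
#                               patience=1.0, without_timestamps=True, fp16=False)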
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: Tensor
|
||||
language: str
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rearrange_kv_cache(self, source_indices) -> None:
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_caching(self) -> None:
|
||||
"""Clean up any resources or hooks after decoding is finished"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchInference(Inference):
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
key_modules = [block.attn.key for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value for block in self.model.decoder.blocks]
|
||||
self.kv_modules = key_modules + value_modules
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
if not self.kv_cache:
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
if source_indices != list(range(len(source_indices))):
|
||||
for module in self.kv_modules:
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(
|
||||
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||
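# Worked example of the length penalty above, with an assumed alpha of 0.6
# (illustrative, not a recommended setting): for a 25-token sequence,
#     penalty = ((5 + 25) / 6) ** 0.6  ~= 2.63
# so a cumulative log probability of -10.0 ranks as -10.0 / 2.63 ~= -3.80.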
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
            True if all sequences have reached the end of text
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / self.temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||
|
||||
completed = (tokens[:, -1] == self.eot).all()
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
beam_size: int,
|
||||
eot: int,
|
||||
inference: Inference,
|
||||
patience: Optional[float] = None,
|
||||
):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert (
|
||||
self.max_candidates > 0
|
||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None: # for the first update
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
|
||||
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens[idx].tolist()
|
||||
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||
sequence = tuple(prefix + [token.item()])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
|
||||
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
|
||||
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||
self.inference.rearrange_kv_cache(source_indices)
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(
|
||||
self.finished_sequences, finished_sequences
|
||||
):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates
|
||||
for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if (
|
||||
len(sequences) < self.beam_size
|
||||
): # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
if len(sequences) >= self.beam_size:
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()]
|
||||
for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
]
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||
"""Apply any filtering or masking to logits in-place
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int]):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
logits[:, self.suppress_tokens] = -np.inf
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
sample_begin: int,
|
||||
max_initial_timestamp_index: Optional[int],
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
            seq = sampled_tokens.tolist()
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
penultimate_was_timestamp = (
|
||||
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
timestamps = sampled_tokens[
|
||||
sampled_tokens.ge(self.tokenizer.timestamp_begin)
|
||||
]
|
||||
if timestamps.numel() > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
# also force each segment to have a nonzero length, to prevent infinite looping
|
||||
if last_was_timestamp and not penultimate_was_timestamp:
|
||||
timestamp_last = timestamps[-1]
|
||||
else:
|
||||
timestamp_last = timestamps[-1] + 1
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = (
|
||||
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
)
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
dim=-1
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=options.task,
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
self.decoder = BeamSearchDecoder(
|
||||
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
)
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(
|
||||
self.options.max_initial_timestamp / precision
|
||||
)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(
|
||||
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||
)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (
|
||||
0 <= options.length_penalty <= 1
|
||||
):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
if isinstance(prefix, str)
|
||||
else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt := self.options.prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: Tensor):
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (
|
||||
self.model.dims.n_audio_ctx,
|
||||
self.model.dims.n_audio_state,
|
||||
):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
):
|
||||
            raise TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(
|
||||
audio_features, self.tokenizer
|
||||
)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
|
||||
                # apply the logit filters, e.g. for suppressing or applying penalty to certain tokens
|
||||
for logit_filter in self.logit_filters:
|
||||
logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
self.inference.cleanup_caching()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features, language=language, language_probs=probs
|
||||
)
|
||||
for features, language, probs in zip(
|
||||
audio_features, languages, language_probs
|
||||
)
|
||||
]
|
||||
|
||||
# repeat text tensors by the group size, for beam search or best-of-n sampling
|
||||
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [
|
||||
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||
]
|
||||
|
||||
fields = (
|
||||
texts,
|
||||
languages,
|
||||
tokens,
|
||||
audio_features,
|
||||
avg_logprobs,
|
||||
no_speech_probs,
|
||||
)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||
*fields
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
model: "Whisper",
|
||||
mel: Tensor,
|
||||
options: DecodingOptions = DecodingOptions(),
|
||||
**kwargs,
|
||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||
A tensor containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
if single := mel.ndim == 2:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
if kwargs:
|
||||
options = replace(options, **kwargs)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
return result[0] if single else result
|
||||
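# Hypothetical end-to-end sketch tying this module together ("clip.wav" and the
# model choice are placeholders; the loader and audio helpers live in sibling modules):
#
#     model = load_model("base", device="cpu")
#     audio = pad_or_trim(load_audio("clip.wav"))
#     mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels)
#     result = decode(model, mel, DecodingOptions(language="en", fp16=False))
#     print(result.text)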
350
whisperlivekit/whisper/model.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import base64
|
||||
import gzip
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
try:
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
SDPA_AVAILABLE = True
|
||||
except (ImportError, RuntimeError, OSError):
|
||||
scaled_dot_product_attention = None
|
||||
SDPA_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_sdpa():
|
||||
prev_state = MultiHeadAttention.use_sdpa
|
||||
try:
|
||||
MultiHeadAttention.use_sdpa = False
|
||||
yield
|
||||
finally:
|
||||
MultiHeadAttention.use_sdpa = prev_state
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    use_sdpa = False  # Disable SDPA to ensure qk is always computed for hooks

    def __init__(self, n_state: int, n_head: int, cache_id: str = ""):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        self.cache_id = cache_id
        self.key.cache_id = f"{cache_id}_key"
        self.value.cache_id = f"{cache_id}_value"

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(
                q, k, v, is_causal=mask is not None and n_ctx > 1
            )
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        return out, qk

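# Illustrative sketch, not part of the change-set above: how MultiHeadAttention is used for
# self- vs. cross-attention. Passing only `x` projects q, k, v from `x`; also passing `xa`
# (e.g. encoder output) projects k and v from `xa`. All sizes are assumptions for the example.
def _example_multi_head_attention() -> None:
    attn = MultiHeadAttention(n_state=384, n_head=6)
    x = torch.zeros(1, 10, 384)      # (batch, text positions, n_state)
    xa = torch.zeros(1, 1500, 384)   # (batch, audio frames, n_state)
    out_self, _ = attn(x)            # self-attention
    out_cross, _ = attn(x, xa)       # cross-attention over the audio features
    assert out_self.shape == x.shape and out_cross.shape == x.shape
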
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, cache_id: str = ""):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_self_attn")
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, cache_id=f"{cache_id}_cross_attn") if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cache_id=f"enc_layer{i}") for i in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

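# Illustrative sketch, not part of the change-set above: running a mel spectrogram through the
# encoder. Sizes are assumptions for the example; note the input needs 2 * n_ctx mel frames
# because conv2 halves the time resolution with stride 2.
def _example_audio_encoder() -> None:
    enc = AudioEncoder(n_mels=80, n_ctx=1500, n_state=384, n_head=6, n_layer=4)
    mel = torch.zeros(1, 80, 3000)             # (batch, n_mels, 2 * n_ctx)
    features = enc(mel)
    assert features.shape == (1, 1500, 384)    # (batch, n_ctx, n_state)
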
class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True, cache_id=f"dec_layer{i}")
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

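# Illustrative sketch, not part of the change-set above: a single decoder pass. The logits are
# obtained by multiplying with the transposed token embedding (weight tying), so the last
# dimension equals n_vocab. Sizes are assumptions for the example only.
def _example_text_decoder() -> None:
    dec = TextDecoder(n_vocab=51865, n_ctx=448, n_state=384, n_head=6, n_layer=4)
    tokens = torch.zeros(1, 3, dtype=torch.long)   # (batch, <= n_ctx) token ids
    audio_features = torch.zeros(1, 1500, 384)     # output of AudioEncoder
    logits = dec(tokens, audio_features)
    assert logits.shape == (1, 3, 51865)           # (batch, tokens, n_vocab)
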
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions, decoder_only: bool = False):
        super().__init__()
        self.dims = dims

        if not decoder_only:
            self.encoder = AudioEncoder(
                self.dims.n_mels,
                self.dims.n_audio_ctx,
                self.dims.n_audio_state,
                self.dims.n_audio_head,
                self.dims.n_audio_layer,
            )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects that can stop the hooks from being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

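# Illustrative sketch, not part of the change-set above: the intended pattern for
# `Whisper.install_kv_cache_hooks` during incremental decoding. `model`, `mel` and `tokens`
# are assumed inputs (a loaded non-decoder-only Whisper instance, a mel spectrogram of shape
# (n_mels, 2 * n_audio_ctx), and a 1-D LongTensor of prompt tokens).
import torch


def greedy_step_with_kv_cache(model, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Encode once, then decode one token at a time while reusing cached keys/values."""
    kv_cache, hooks = model.install_kv_cache_hooks()
    try:
        audio_features = model.embed_audio(mel.unsqueeze(0))
        # first call: the full prompt goes through the decoder and fills the cache
        logits = model.decoder(tokens.unsqueeze(0), audio_features, kv_cache=kv_cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        # second call: only the newly sampled token is fed; the hooks append its
        # key/value projections to the cached ones, so earlier positions are not recomputed
        logits = model.decoder(next_token, audio_features, kv_cache=kv_cache)
        return logits[:, -1].argmax(dim=-1)
    finally:
        for hook in hooks:
            hook.remove()
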
2
whisperlivekit/whisper/normalizers/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
80
whisperlivekit/whisper/normalizers/basic.py
Normal file
@@ -0,0 +1,80 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        (
            c
            if c in keep
            else (
                ADDITIONAL_DIACRITICS[c]
                if c in ADDITIONAL_DIACRITICS
                else (
                    ""
                    if unicodedata.category(c) == "Mn"
                    else " " if unicodedata.category(c)[0] in "MSP" else c
                )
            )
        )
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c
        for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
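# Illustrative usage sketch, not part of the change-set above. The import path mirrors the
# file layout in this diff; the input string and the described output are examples, not tests.
from whisperlivekit.whisper.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer(remove_diacritics=True)
print(normalizer("Caffè  [background noise]  Latte!"))
# prints roughly "caffe latte": lowercased, bracketed span dropped, punctuation replaced
# by spaces, diacritics stripped, successive whitespace collapsed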
1741
whisperlivekit/whisper/normalizers/english.json
Normal file
550
whisperlivekit/whisper/normalizers/english.py
Normal file
@@ -0,0 +1,550 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if (
|
||||
prev in self.tens and ones < 10
|
||||
): # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
            # perfect tenses; ideally this should cover any past participle, but that's harder
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||
|
||||
return s
|
||||
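# Illustrative usage sketch, not part of the change-set above. The import path mirrors this
# diff's layout; the sample sentence and the sketched output are examples, not tests.
from whisperlivekit.whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()
print(normalizer("Mr. Brown paid twenty-five dollars, didn't he?"))
# expected along the lines of "mister brown paid $25 did not he": titles expanded,
# contractions split, spelled-out numbers converted, punctuation dropped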
388
whisperlivekit/whisper/timing.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import itertools
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
# F.pad requires the padding width to be smaller than the input dimension
|
||||
return x
|
||||
|
||||
if (ndim := x.ndim) <= 2:
|
||||
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
|
||||
x = x[None, None, :]
|
||||
|
||||
assert (
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
result = None
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
if x.is_cuda:
|
||||
try:
|
||||
from .triton_ops import median_filter_cuda
|
||||
|
||||
result = median_filter_cuda(x, filter_width)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower median kernel implementation..."
|
||||
)
|
||||
|
||||
if result is None:
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
result = []
|
||||
while i > 0 or j > 0:
|
||||
result.append((i - 1, j - 1))
|
||||
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise ValueError("Unexpected trace[i, j]")
|
||||
|
||||
result = np.array(result)
|
||||
return result[::-1, :].T
|
||||
|
||||
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
for i in range(1, N + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = x[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
return backtrace(trace)
|
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.to(x.device)
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
if x.is_cuda:
|
||||
try:
|
||||
return dtw_cuda(x)
|
||||
except (RuntimeError, subprocess.CalledProcessError):
|
||||
warnings.warn(
|
||||
"Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
|
||||
"falling back to a slower DTW implementation..."
|
||||
)
|
||||
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
|
||||
|
||||
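# Illustrative sketch, not part of the change-set above: the DTW helpers on a toy cost matrix.
# `dtw` expects a (text tokens x audio frames) cost matrix and returns the row/column indices
# of a monotonic minimum-cost path; `find_alignment` below passes the negated attention
# weights so that high attention means low cost. The matrix here is an assumption for the example.
def _example_dtw_toy() -> None:
    cost = torch.tensor(
        [[0.0, 1.0, 1.0, 1.0],
         [1.0, 0.0, 1.0, 1.0],
         [1.0, 1.0, 0.0, 0.0]]
    )
    text_indices, time_indices = dtw(cost)
    # the path is monotonic and ends at the bottom-right corner of the matrix
    assert text_indices[-1] == cost.shape[0] - 1
    assert time_indices[-1] == cost.shape[1] - 1
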
@dataclass
|
||||
class WordTiming:
|
||||
word: str
|
||||
tokens: List[int]
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
from .model import disable_sdpa
|
||||
|
||||
with torch.no_grad(), disable_sdpa():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
if len(word_tokens) <= 1:
|
||||
# return on eot only
|
||||
# >>> np.pad([], (1, 0))
|
||||
# array([0.])
|
||||
# This results in crashes when we lookup jump_times with float, like
|
||||
# IndexError: arrays used as indices must be of integer (or boolean) type
|
||||
return []
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
np.mean(text_token_probs[i:j])
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
return [
|
||||
WordTiming(word, tokens, start, end, probability)
|
||||
for word, tokens, start, end, probability in zip(
|
||||
words, word_tokens, start_times, end_times, word_probabilities
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||
# merge prepended punctuations
|
||||
i = len(alignment) - 2
|
||||
j = len(alignment) - 1
|
||||
while i >= 0:
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||
# prepend it to the following word
|
||||
following.word = previous.word + following.word
|
||||
following.tokens = previous.tokens + following.tokens
|
||||
previous.word = ""
|
||||
previous.tokens = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# merge appended punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(alignment):
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if not previous.word.endswith(" ") and following.word in appended:
|
||||
# append it to the previous word
|
||||
previous.word = previous.word + following.word
|
||||
previous.tokens = previous.tokens + following.tokens
|
||||
following.word = ""
|
||||
following.tokens = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
|
||||
def add_word_timestamps(
|
||||
*,
|
||||
segments: List[dict],
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
last_speech_timestamp: float,
|
||||
**kwargs,
|
||||
):
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
text_tokens_per_segment = [
|
||||
[token for token in segment["tokens"] if token < tokenizer.eot]
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||
word_durations = np.array([t.end - t.start for t in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
if len(word_durations) > 0:
|
||||
sentence_end_marks = ".。!!??"
|
||||
# ensure words at sentence boundaries are not longer than twice the median word duration.
|
||||
for i in range(1, len(alignment)):
|
||||
if alignment[i].end - alignment[i].start > max_duration:
|
||||
if alignment[i].word in sentence_end_marks:
|
||||
alignment[i].end = alignment[i].start + max_duration
|
||||
elif alignment[i - 1].word in sentence_end_marks:
|
||||
alignment[i].start = alignment[i].end - max_duration
|
||||
|
||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||
|
||||
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||
word_index = 0
|
||||
|
||||
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
||||
saved_tokens = 0
|
||||
words = []
|
||||
|
||||
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
||||
timing = alignment[word_index]
|
||||
|
||||
if timing.word:
|
||||
words.append(
|
||||
dict(
|
||||
word=timing.word,
|
||||
start=round(time_offset + timing.start, 2),
|
||||
end=round(time_offset + timing.end, 2),
|
||||
probability=timing.probability,
|
||||
)
|
||||
)
|
||||
|
||||
saved_tokens += len(timing.tokens)
|
||||
word_index += 1
|
||||
|
||||
# hack: truncate long words at segment boundaries.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
if len(words) > 0:
|
||||
            # ensure the first and second words after a pause are not longer than
|
||||
# twice the median word duration.
|
||||
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
||||
words[0]["end"] - words[0]["start"] > max_duration
|
||||
or (
|
||||
len(words) > 1
|
||||
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
||||
)
|
||||
):
|
||||
if (
|
||||
len(words) > 1
|
||||
and words[1]["end"] - words[1]["start"] > max_duration
|
||||
):
|
||||
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
|
||||
words[0]["end"] = words[1]["start"] = boundary
|
||||
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
||||
|
||||
# prefer the segment-level start timestamp if the first word is too long.
|
||||
if (
|
||||
segment["start"] < words[0]["end"]
|
||||
and segment["start"] - 0.5 > words[0]["start"]
|
||||
):
|
||||
words[0]["start"] = max(
|
||||
0, min(words[0]["end"] - median_duration, segment["start"])
|
||||
)
|
||||
else:
|
||||
segment["start"] = words[0]["start"]
|
||||
|
||||
# prefer the segment-level end timestamp if the last word is too long.
|
||||
if (
|
||||
segment["end"] > words[-1]["start"]
|
||||
and segment["end"] + 0.5 < words[-1]["end"]
|
||||
):
|
||||
words[-1]["end"] = max(
|
||||
words[-1]["start"] + median_duration, segment["end"]
|
||||
)
|
||||
else:
|
||||
segment["end"] = words[-1]["end"]
|
||||
|
||||
last_speech_timestamp = segment["end"]
|
||||
|
||||
segment["words"] = words
|
||||
395
whisperlivekit/whisper/tokenizer.py
Normal file
@@ -0,0 +1,395 @@
|
||||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
"mandarin": "zh",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
encoding: tiktoken.Encoding
|
||||
num_languages: int
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for special in self.encoding.special_tokens_set:
|
||||
special_token = self.encoding.encode_single_token(special)
|
||||
self.special_tokens[special] = special_token
|
||||
|
||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
if self.task is not None:
|
||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||
sot_sequence.append(task_token)
|
||||
|
||||
self.sot_sequence = tuple(sot_sequence)
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.encoding.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.encoding.eot_token
|
||||
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self.special_tokens["<|transcribe|>"]
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self.special_tokens["<|translate|>"]
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self.special_tokens["<|startoftranscript|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self.special_tokens["<|startoflm|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self.special_tokens["<|startofprev|>"]
|
||||
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self.special_tokens["<|nospeech|>"]
|
||||
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self.special_tokens["<|notimestamps|>"]
|
||||
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.special_tokens["<|0.00|>"]
|
||||
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
return self.to_language_token(self.language)
|
||||
|
||||
def to_language_token(self, language):
|
||||
if token := self.special_tokens.get(f"<|{language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {language} not found in tokenizer.")
|
||||
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)[: self.num_languages]
|
||||
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
||||
|
||||
@cached_property
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [
|
||||
self.encoding.encode(symbol),
|
||||
self.encoding.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
return self.split_tokens_on_unicode(tokens)
|
||||
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
decoded_full = self.decode_with_timestamps(tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||
== replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
|
||||
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||
special = subword_tokens[0] >= self.eot
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in string.punctuation
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
num_languages: int = 99,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
encoding_name = "multilingual"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
)
|
||||
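# Illustrative usage sketch, not part of the change-set above. The import path mirrors this
# diff's layout; loading the encoding requires the tiktoken vocabulary files under the
# package's assets/ directory.
from whisperlivekit.whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
ids = tokenizer.encode(" hello world")
print(ids, "->", tokenizer.decode(ids))
print(tokenizer.sot_sequence)   # (<|startoftranscript|>, <|en|>, <|transcribe|>) token ids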