From b69eaf82be6a051d664e5b0778b9dc6a8852ba6d Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 15 Mar 2026 18:30:00 +0100 Subject: [PATCH] qwen3 simul+kv: optimized streaming with kv cache reuse --- scripts/alignment_heads_qwen3_asr_0.6B.json | 3346 +++++++++++++++++++ whisperlivekit/cascade_bridge.py | 116 + whisperlivekit/core.py | 12 + whisperlivekit/qwen3_simul_kv.py | 787 +++++ 4 files changed, 4261 insertions(+) create mode 100644 scripts/alignment_heads_qwen3_asr_0.6B.json create mode 100644 whisperlivekit/cascade_bridge.py create mode 100644 whisperlivekit/qwen3_simul_kv.py diff --git a/scripts/alignment_heads_qwen3_asr_0.6B.json b/scripts/alignment_heads_qwen3_asr_0.6B.json new file mode 100644 index 0000000..879692c --- /dev/null +++ b/scripts/alignment_heads_qwen3_asr_0.6B.json @@ -0,0 +1,3346 @@ +{ + "model": "Qwen/Qwen3-ASR-0.6B", + "language": "English", + "num_layers": 28, + "num_heads": 16, + "num_kv_heads": 8, + "num_samples": 30, + "total_alignable_tokens": 533, + "ts_threshold": 0.1, + "ts_matrix": [ + [ + 0.08067542213883677, + 0.0825515947467167, + 0.11819887429643527, + 0.1575984990619137, + 0.04127579737335835, + 0.04878048780487805, + 0.009380863039399626, + 0.09193245778611632, + 0.028142589118198873, + 0.08818011257035648, + 0.08442776735459662, + 0.08818011257035648, + 0.043151969981238276, + 0.0150093808630394, + 0.058161350844277676, + 0.0525328330206379 + ], + [ + 0.075046904315197, + 0.0900562851782364, + 0.08067542213883677, + 0.14634146341463414, + 0.06566604127579738, + 0.020637898686679174, + 0.013133208255159476, + 0.0225140712945591, + 0.2870544090056285, + 0.0225140712945591, + 0.043151969981238276, + 0.0225140712945591, + 0.009380863039399626, + 0.0600375234521576, + 0.0975609756097561, + 0.150093808630394 + ], + [ + 0.07129455909943715, + 0.04878048780487805, + 0.10881801125703565, + 0.6772983114446529, + 0.03564727954971857, + 0.0450281425891182, + 0.19136960600375236, + 0.01876172607879925, + 0.15572232645403378, + 
0.0975609756097561, + 0.6960600375234521, + 0.7617260787992496, + 0.0825515947467167, + 0.07129455909943715, + 0.24202626641651032, + 0.01125703564727955 + ], + [ + 0.07692307692307693, + 0.0225140712945591, + 0.17636022514071295, + 0.17823639774859287, + 0.324577861163227, + 0.08818011257035648, + 0.11069418386491557, + 0.0675422138836773, + 0.13883677298311445, + 0.09380863039399624, + 0.797373358348968, + 0.6848030018761726, + 0.0450281425891182, + 0.2776735459662289, + 0.26454033771106944, + 0.18761726078799248 + ], + [ + 0.04127579737335835, + 0.06566604127579738, + 0.10881801125703565, + 0.0900562851782364, + 0.17448405253283303, + 0.043151969981238276, + 0.0300187617260788, + 0.09380863039399624, + 0.15196998123827393, + 0.11632270168855535, + 0.34709193245778613, + 0.24202626641651032, + 0.6041275797373359, + 0.7467166979362101, + 0.09943714821763602, + 0.32082551594746717 + ], + [ + 0.12195121951219512, + 0.15384615384615385, + 0.10881801125703565, + 0.075046904315197, + 0.23827392120075047, + 0.34896810506566606, + 0.09943714821763602, + 0.10881801125703565, + 0.19887429643527205, + 0.1050656660412758, + 0.5234521575984991, + 0.14634146341463414, + 0.020637898686679174, + 0.03377110694183865, + 0.14634146341463414, + 0.3621013133208255 + ], + [ + 0.275797373358349, + 0.2551594746716698, + 0.06378986866791744, + 0.11444652908067542, + 0.21200750469043153, + 0.18198874296435272, + 0.8086303939962477, + 0.8198874296435272, + 0.0375234521575985, + 0.3076923076923077, + 0.7879924953095685, + 0.8067542213883677, + 0.726078799249531, + 0.799249530956848, + 0.2795497185741088, + 0.22326454033771106 + ], + [ + 0.4352720450281426, + 0.03377110694183865, + 0.06378986866791744, + 0.075046904315197, + 0.3789868667917448, + 0.26454033771106944, + 0.23076923076923078, + 0.05628517823639775, + 0.058161350844277676, + 0.0450281425891182, + 0.09943714821763602, + 0.150093808630394, + 0.17073170731707318, + 0.21200750469043153, + 0.1425891181988743, + 0.1125703564727955 + 
], + [ + 0.1651031894934334, + 0.6904315196998124, + 0.324577861163227, + 0.07692307692307693, + 0.6060037523452158, + 0.3076923076923077, + 0.30393996247654786, + 0.35834896810506567, + 0.0975609756097561, + 0.15947467166979362, + 0.14071294559099437, + 0.14446529080675422, + 0.11069418386491557, + 0.1726078799249531, + 0.35834896810506567, + 0.07129455909943715 + ], + [ + 0.2551594746716698, + 0.058161350844277676, + 0.25328330206378985, + 0.15384615384615385, + 0.24577861163227016, + 0.2551594746716698, + 0.028142589118198873, + 0.2701688555347092, + 0.3771106941838649, + 0.324577861163227, + 0.18198874296435272, + 0.10694183864915573, + 0.6754221388367729, + 0.6547842401500938, + 0.1275797373358349, + 0.016885553470919325 + ], + [ + 0.03564727954971857, + 0.005628517823639775, + 0.350844277673546, + 0.2776735459662289, + 0.23639774859287055, + 0.38649155722326456, + 0.03564727954971857, + 0.02626641651031895, + 0.11632270168855535, + 0.24577861163227016, + 0.13696060037523453, + 0.22138836772983114, + 0.1575984990619137, + 0.2026266416510319, + 0.07692307692307693, + 0.1350844277673546 + ], + [ + 0.30956848030018763, + 0.35647279549718575, + 0.849906191369606, + 0.7936210131332082, + 0.15947467166979362, + 0.26641651031894936, + 0.23639774859287055, + 0.3302063789868668, + 0.6716697936210131, + 0.45778611632270166, + 0.4709193245778612, + 0.7373358348968105, + 0.8067542213883677, + 0.8348968105065666, + 0.03189493433395872, + 0.09193245778611632 + ], + [ + 0.46153846153846156, + 0.4896810506566604, + 0.19887429643527205, + 0.30956848030018763, + 0.0900562851782364, + 0.13320825515947468, + 0.7185741088180112, + 0.1125703564727955, + 0.44652908067542213, + 0.11632270168855535, + 0.2964352720450281, + 0.075046904315197, + 0.28142589118198874, + 0.14071294559099437, + 0.2795497185741088, + 0.21575984990619138 + ], + [ + 0.7560975609756098, + 0.34709193245778613, + 0.23076923076923078, + 0.19136960600375236, + 0.4971857410881801, + 0.18198874296435272, + 
0.8442776735459663, + 0.8048780487804879, + 0.05065666041275797, + 0.0450281425891182, + 0.15196998123827393, + 0.7542213883677298, + 0.0300187617260788, + 0.03189493433395872, + 0.5666041275797373, + 0.6022514071294559 + ], + [ + 0.28142589118198874, + 0.10881801125703565, + 0.14821763602251406, + 0.10318949343339587, + 0.0225140712945591, + 0.23639774859287055, + 0.28330206378986866, + 0.2045028142589118, + 0.11632270168855535, + 0.13696060037523453, + 0.19136960600375236, + 0.23827392120075047, + 0.3227016885553471, + 0.2945590994371482, + 0.8330206378986866, + 0.8198874296435272 + ], + [ + 0.09568480300187618, + 0.150093808630394, + 0.2551594746716698, + 0.13320825515947468, + 0.1575984990619137, + 0.18574108818011256, + 0.2776735459662289, + 0.16885553470919323, + 0.05065666041275797, + 0.16885553470919323, + 0.5909943714821764, + 0.18198874296435272, + 0.0675422138836773, + 0.04690431519699812, + 0.13696060037523453, + 0.15572232645403378 + ], + [ + 0.075046904315197, + 0.03189493433395872, + 0.07879924953095685, + 0.11819887429643527, + 0.06378986866791744, + 0.24390243902439024, + 0.2926829268292683, + 0.5703564727954972, + 0.24953095684803, + 0.31894934333958724, + 0.7429643527204502, + 0.5159474671669794, + 0.4915572232645403, + 0.549718574108818, + 0.8086303939962477, + 0.7523452157598499 + ], + [ + 0.36397748592870544, + 0.34896810506566606, + 0.275797373358349, + 0.23452157598499063, + 0.10694183864915573, + 0.04690431519699812, + 0.01876172607879925, + 0.024390243902439025, + 0.38461538461538464, + 0.30956848030018763, + 0.2626641651031895, + 0.24390243902439024, + 0.32082551594746717, + 0.45590994371482174, + 0.08818011257035648, + 0.08442776735459662 + ], + [ + 0.024390243902439025, + 0.024390243902439025, + 0.4146341463414634, + 0.7354596622889306, + 0.324577861163227, + 0.7354596622889306, + 0.20075046904315197, + 0.17823639774859287, + 0.14821763602251406, + 0.09380863039399624, + 0.4427767354596623, + 0.2964352720450281, + 0.0225140712945591, + 
0.22326454033771106, + 0.06941838649155722, + 0.17073170731707318 + ], + [ + 0.0975609756097561, + 0.20825515947467166, + 0.47842401500938087, + 0.6041275797373359, + 0.49906191369606, + 0.7073170731707317, + 0.37335834896810505, + 0.7786116322701688, + 0.4521575984990619, + 0.5647279549718575, + 0.07879924953095685, + 0.07692307692307693, + 0.4596622889305816, + 0.474671669793621, + 0.01876172607879925, + 0.028142589118198873 + ], + [ + 0.09193245778611632, + 0.08067542213883677, + 0.2626641651031895, + 0.8555347091932458, + 0.4352720450281426, + 0.2776735459662289, + 0.38649155722326456, + 0.6116322701688556, + 0.32833020637898686, + 0.04127579737335835, + 0.6097560975609756, + 0.6322701688555347, + 0.41275797373358347, + 0.27392120075046905, + 0.7091932457786116, + 0.701688555347092 + ], + [ + 0.6360225140712945, + 0.6172607879924953, + 0.15572232645403378, + 0.0450281425891182, + 0.32833020637898686, + 0.0900562851782364, + 0.2795497185741088, + 0.26454033771106944, + 0.7692307692307693, + 0.7842401500938087, + 0.33583489681050654, + 0.43151969981238275, + 0.6228893058161351, + 0.4803001876172608, + 0.40337711069418386, + 0.4634146341463415 + ], + [ + 0.25328330206378985, + 0.3395872420262664, + 0.15196998123827393, + 0.06566604127579738, + 0.3452157598499062, + 0.2851782363977486, + 0.30956848030018763, + 0.7054409005628518, + 0.6979362101313321, + 0.701688555347092, + 0.1801125703564728, + 0.2401500938086304, + 0.6716697936210131, + 0.6228893058161351, + 0.18761726078799248, + 0.10881801125703565 + ], + [ + 0.5553470919324578, + 0.5647279549718575, + 0.0600375234521576, + 0.10881801125703565, + 0.6772983114446529, + 0.2682926829268293, + 0.5590994371482176, + 0.7091932457786116, + 0.05065666041275797, + 0.07317073170731707, + 0.5103189493433395, + 0.3789868667917448, + 0.275797373358349, + 0.16885553470919323, + 0.701688555347092, + 0.6923076923076923 + ], + [ + 0.043151969981238276, + 0.05065666041275797, + 0.054409005628517824, + 0.0600375234521576, + 
0.46716697936210133, + 0.6904315196998124, + 0.626641651031895, + 0.6848030018761726, + 0.09943714821763602, + 0.09193245778611632, + 0.6566604127579737, + 0.6679174484052532, + 0.6697936210131332, + 0.6772983114446529, + 0.6979362101313321, + 0.6904315196998124 + ], + [ + 0.13696060037523453, + 0.09380863039399624, + 0.01876172607879925, + 0.08442776735459662, + 0.6923076923076923, + 0.701688555347092, + 0.6472795497185742, + 0.6772983114446529, + 0.32833020637898686, + 0.5534709193245778, + 0.6716697936210131, + 0.6941838649155723, + 0.6622889305816135, + 0.6566604127579737, + 0.6360225140712945, + 0.4521575984990619 + ], + [ + 0.49343339587242024, + 0.4709193245778612, + 0.6529080675422139, + 0.6378986866791745, + 0.6322701688555347, + 0.6041275797373359, + 0.23827392120075047, + 0.6322701688555347, + 0.6923076923076923, + 0.2926829268292683, + 0.03189493433395872, + 0.3058161350844278, + 0.07317073170731707, + 0.08630393996247655, + 0.6060037523452158, + 0.5590994371482176 + ], + [ + 0.1350844277673546, + 0.13883677298311445, + 0.08818011257035648, + 0.10694183864915573, + 0.04878048780487805, + 0.1350844277673546, + 0.09380863039399624, + 0.09380863039399624, + 0.1294559099437148, + 0.1125703564727955, + 0.13133208255159476, + 0.06941838649155722, + 0.075046904315197, + 0.10318949343339587, + 0.0975609756097561, + 0.09193245778611632 + ] + ], + "alignment_heads": [ + { + "layer": 20, + "head": 3, + "ts": 0.8555 + }, + { + "layer": 11, + "head": 2, + "ts": 0.8499 + }, + { + "layer": 13, + "head": 6, + "ts": 0.8443 + }, + { + "layer": 11, + "head": 13, + "ts": 0.8349 + }, + { + "layer": 14, + "head": 14, + "ts": 0.833 + }, + { + "layer": 6, + "head": 7, + "ts": 0.8199 + }, + { + "layer": 14, + "head": 15, + "ts": 0.8199 + }, + { + "layer": 6, + "head": 6, + "ts": 0.8086 + }, + { + "layer": 16, + "head": 14, + "ts": 0.8086 + }, + { + "layer": 6, + "head": 11, + "ts": 0.8068 + }, + { + "layer": 11, + "head": 12, + "ts": 0.8068 + }, + { + "layer": 13, + "head": 7, 
+ "ts": 0.8049 + }, + { + "layer": 6, + "head": 13, + "ts": 0.7992 + }, + { + "layer": 3, + "head": 10, + "ts": 0.7974 + }, + { + "layer": 11, + "head": 3, + "ts": 0.7936 + }, + { + "layer": 6, + "head": 10, + "ts": 0.788 + }, + { + "layer": 21, + "head": 9, + "ts": 0.7842 + }, + { + "layer": 19, + "head": 7, + "ts": 0.7786 + }, + { + "layer": 21, + "head": 8, + "ts": 0.7692 + }, + { + "layer": 2, + "head": 11, + "ts": 0.7617 + }, + { + "layer": 13, + "head": 0, + "ts": 0.7561 + }, + { + "layer": 13, + "head": 11, + "ts": 0.7542 + }, + { + "layer": 16, + "head": 15, + "ts": 0.7523 + }, + { + "layer": 4, + "head": 13, + "ts": 0.7467 + }, + { + "layer": 16, + "head": 10, + "ts": 0.743 + }, + { + "layer": 11, + "head": 11, + "ts": 0.7373 + }, + { + "layer": 18, + "head": 3, + "ts": 0.7355 + }, + { + "layer": 18, + "head": 5, + "ts": 0.7355 + }, + { + "layer": 6, + "head": 12, + "ts": 0.7261 + }, + { + "layer": 12, + "head": 6, + "ts": 0.7186 + }, + { + "layer": 20, + "head": 14, + "ts": 0.7092 + }, + { + "layer": 23, + "head": 7, + "ts": 0.7092 + }, + { + "layer": 19, + "head": 5, + "ts": 0.7073 + }, + { + "layer": 22, + "head": 7, + "ts": 0.7054 + }, + { + "layer": 20, + "head": 15, + "ts": 0.7017 + }, + { + "layer": 22, + "head": 9, + "ts": 0.7017 + }, + { + "layer": 23, + "head": 14, + "ts": 0.7017 + }, + { + "layer": 25, + "head": 5, + "ts": 0.7017 + }, + { + "layer": 22, + "head": 8, + "ts": 0.6979 + }, + { + "layer": 24, + "head": 14, + "ts": 0.6979 + }, + { + "layer": 2, + "head": 10, + "ts": 0.6961 + }, + { + "layer": 25, + "head": 11, + "ts": 0.6942 + }, + { + "layer": 23, + "head": 15, + "ts": 0.6923 + }, + { + "layer": 25, + "head": 4, + "ts": 0.6923 + }, + { + "layer": 26, + "head": 8, + "ts": 0.6923 + }, + { + "layer": 8, + "head": 1, + "ts": 0.6904 + }, + { + "layer": 24, + "head": 5, + "ts": 0.6904 + }, + { + "layer": 24, + "head": 15, + "ts": 0.6904 + }, + { + "layer": 3, + "head": 11, + "ts": 0.6848 + }, + { + "layer": 24, + "head": 7, + "ts": 0.6848 
+ }, + { + "layer": 2, + "head": 3, + "ts": 0.6773 + }, + { + "layer": 23, + "head": 4, + "ts": 0.6773 + }, + { + "layer": 24, + "head": 13, + "ts": 0.6773 + }, + { + "layer": 25, + "head": 7, + "ts": 0.6773 + }, + { + "layer": 9, + "head": 12, + "ts": 0.6754 + }, + { + "layer": 11, + "head": 8, + "ts": 0.6717 + }, + { + "layer": 22, + "head": 12, + "ts": 0.6717 + }, + { + "layer": 25, + "head": 10, + "ts": 0.6717 + }, + { + "layer": 24, + "head": 12, + "ts": 0.6698 + }, + { + "layer": 24, + "head": 11, + "ts": 0.6679 + }, + { + "layer": 25, + "head": 12, + "ts": 0.6623 + }, + { + "layer": 24, + "head": 10, + "ts": 0.6567 + }, + { + "layer": 25, + "head": 13, + "ts": 0.6567 + }, + { + "layer": 9, + "head": 13, + "ts": 0.6548 + }, + { + "layer": 26, + "head": 2, + "ts": 0.6529 + }, + { + "layer": 25, + "head": 6, + "ts": 0.6473 + }, + { + "layer": 26, + "head": 3, + "ts": 0.6379 + }, + { + "layer": 21, + "head": 0, + "ts": 0.636 + }, + { + "layer": 25, + "head": 14, + "ts": 0.636 + }, + { + "layer": 20, + "head": 11, + "ts": 0.6323 + }, + { + "layer": 26, + "head": 4, + "ts": 0.6323 + }, + { + "layer": 26, + "head": 7, + "ts": 0.6323 + }, + { + "layer": 24, + "head": 6, + "ts": 0.6266 + }, + { + "layer": 21, + "head": 12, + "ts": 0.6229 + }, + { + "layer": 22, + "head": 13, + "ts": 0.6229 + }, + { + "layer": 21, + "head": 1, + "ts": 0.6173 + }, + { + "layer": 20, + "head": 7, + "ts": 0.6116 + }, + { + "layer": 20, + "head": 10, + "ts": 0.6098 + }, + { + "layer": 8, + "head": 4, + "ts": 0.606 + }, + { + "layer": 26, + "head": 14, + "ts": 0.606 + }, + { + "layer": 4, + "head": 12, + "ts": 0.6041 + }, + { + "layer": 19, + "head": 3, + "ts": 0.6041 + }, + { + "layer": 26, + "head": 5, + "ts": 0.6041 + }, + { + "layer": 13, + "head": 15, + "ts": 0.6023 + }, + { + "layer": 15, + "head": 10, + "ts": 0.591 + }, + { + "layer": 16, + "head": 7, + "ts": 0.5704 + }, + { + "layer": 13, + "head": 14, + "ts": 0.5666 + }, + { + "layer": 19, + "head": 9, + "ts": 0.5647 + }, + { + 
"layer": 23, + "head": 1, + "ts": 0.5647 + }, + { + "layer": 23, + "head": 6, + "ts": 0.5591 + }, + { + "layer": 26, + "head": 15, + "ts": 0.5591 + }, + { + "layer": 23, + "head": 0, + "ts": 0.5553 + }, + { + "layer": 25, + "head": 9, + "ts": 0.5535 + }, + { + "layer": 16, + "head": 13, + "ts": 0.5497 + }, + { + "layer": 5, + "head": 10, + "ts": 0.5235 + }, + { + "layer": 16, + "head": 11, + "ts": 0.5159 + }, + { + "layer": 23, + "head": 10, + "ts": 0.5103 + }, + { + "layer": 19, + "head": 4, + "ts": 0.4991 + }, + { + "layer": 13, + "head": 4, + "ts": 0.4972 + }, + { + "layer": 26, + "head": 0, + "ts": 0.4934 + }, + { + "layer": 16, + "head": 12, + "ts": 0.4916 + }, + { + "layer": 12, + "head": 1, + "ts": 0.4897 + }, + { + "layer": 21, + "head": 13, + "ts": 0.4803 + }, + { + "layer": 19, + "head": 2, + "ts": 0.4784 + }, + { + "layer": 19, + "head": 13, + "ts": 0.4747 + }, + { + "layer": 11, + "head": 10, + "ts": 0.4709 + }, + { + "layer": 26, + "head": 1, + "ts": 0.4709 + }, + { + "layer": 24, + "head": 4, + "ts": 0.4672 + }, + { + "layer": 21, + "head": 15, + "ts": 0.4634 + }, + { + "layer": 12, + "head": 0, + "ts": 0.4615 + }, + { + "layer": 19, + "head": 12, + "ts": 0.4597 + }, + { + "layer": 11, + "head": 9, + "ts": 0.4578 + }, + { + "layer": 17, + "head": 13, + "ts": 0.4559 + }, + { + "layer": 19, + "head": 8, + "ts": 0.4522 + }, + { + "layer": 25, + "head": 15, + "ts": 0.4522 + }, + { + "layer": 12, + "head": 8, + "ts": 0.4465 + }, + { + "layer": 18, + "head": 10, + "ts": 0.4428 + }, + { + "layer": 7, + "head": 0, + "ts": 0.4353 + }, + { + "layer": 20, + "head": 4, + "ts": 0.4353 + }, + { + "layer": 21, + "head": 11, + "ts": 0.4315 + }, + { + "layer": 18, + "head": 2, + "ts": 0.4146 + }, + { + "layer": 20, + "head": 12, + "ts": 0.4128 + }, + { + "layer": 21, + "head": 14, + "ts": 0.4034 + }, + { + "layer": 10, + "head": 5, + "ts": 0.3865 + }, + { + "layer": 20, + "head": 6, + "ts": 0.3865 + }, + { + "layer": 17, + "head": 8, + "ts": 0.3846 + }, + { + "layer": 
7, + "head": 4, + "ts": 0.379 + }, + { + "layer": 23, + "head": 11, + "ts": 0.379 + }, + { + "layer": 9, + "head": 8, + "ts": 0.3771 + }, + { + "layer": 19, + "head": 6, + "ts": 0.3734 + }, + { + "layer": 17, + "head": 0, + "ts": 0.364 + }, + { + "layer": 5, + "head": 15, + "ts": 0.3621 + }, + { + "layer": 8, + "head": 7, + "ts": 0.3583 + }, + { + "layer": 8, + "head": 14, + "ts": 0.3583 + }, + { + "layer": 11, + "head": 1, + "ts": 0.3565 + }, + { + "layer": 10, + "head": 2, + "ts": 0.3508 + }, + { + "layer": 5, + "head": 5, + "ts": 0.349 + }, + { + "layer": 17, + "head": 1, + "ts": 0.349 + }, + { + "layer": 4, + "head": 10, + "ts": 0.3471 + }, + { + "layer": 13, + "head": 1, + "ts": 0.3471 + }, + { + "layer": 22, + "head": 4, + "ts": 0.3452 + }, + { + "layer": 22, + "head": 1, + "ts": 0.3396 + }, + { + "layer": 21, + "head": 10, + "ts": 0.3358 + }, + { + "layer": 11, + "head": 7, + "ts": 0.3302 + }, + { + "layer": 20, + "head": 8, + "ts": 0.3283 + }, + { + "layer": 21, + "head": 4, + "ts": 0.3283 + }, + { + "layer": 25, + "head": 8, + "ts": 0.3283 + }, + { + "layer": 3, + "head": 4, + "ts": 0.3246 + }, + { + "layer": 8, + "head": 2, + "ts": 0.3246 + }, + { + "layer": 9, + "head": 9, + "ts": 0.3246 + }, + { + "layer": 18, + "head": 4, + "ts": 0.3246 + }, + { + "layer": 14, + "head": 12, + "ts": 0.3227 + }, + { + "layer": 4, + "head": 15, + "ts": 0.3208 + }, + { + "layer": 17, + "head": 12, + "ts": 0.3208 + }, + { + "layer": 16, + "head": 9, + "ts": 0.3189 + }, + { + "layer": 11, + "head": 0, + "ts": 0.3096 + }, + { + "layer": 12, + "head": 3, + "ts": 0.3096 + }, + { + "layer": 17, + "head": 9, + "ts": 0.3096 + }, + { + "layer": 22, + "head": 6, + "ts": 0.3096 + }, + { + "layer": 6, + "head": 9, + "ts": 0.3077 + }, + { + "layer": 8, + "head": 5, + "ts": 0.3077 + }, + { + "layer": 26, + "head": 11, + "ts": 0.3058 + }, + { + "layer": 8, + "head": 6, + "ts": 0.3039 + }, + { + "layer": 12, + "head": 10, + "ts": 0.2964 + }, + { + "layer": 18, + "head": 11, + "ts": 0.2964 
+ }, + { + "layer": 14, + "head": 13, + "ts": 0.2946 + }, + { + "layer": 16, + "head": 6, + "ts": 0.2927 + }, + { + "layer": 26, + "head": 9, + "ts": 0.2927 + }, + { + "layer": 1, + "head": 8, + "ts": 0.2871 + }, + { + "layer": 22, + "head": 5, + "ts": 0.2852 + }, + { + "layer": 14, + "head": 6, + "ts": 0.2833 + }, + { + "layer": 12, + "head": 12, + "ts": 0.2814 + }, + { + "layer": 14, + "head": 0, + "ts": 0.2814 + }, + { + "layer": 6, + "head": 14, + "ts": 0.2795 + }, + { + "layer": 12, + "head": 14, + "ts": 0.2795 + }, + { + "layer": 21, + "head": 6, + "ts": 0.2795 + }, + { + "layer": 3, + "head": 13, + "ts": 0.2777 + }, + { + "layer": 10, + "head": 3, + "ts": 0.2777 + }, + { + "layer": 15, + "head": 6, + "ts": 0.2777 + }, + { + "layer": 20, + "head": 5, + "ts": 0.2777 + }, + { + "layer": 6, + "head": 0, + "ts": 0.2758 + }, + { + "layer": 17, + "head": 2, + "ts": 0.2758 + }, + { + "layer": 23, + "head": 12, + "ts": 0.2758 + }, + { + "layer": 20, + "head": 13, + "ts": 0.2739 + }, + { + "layer": 9, + "head": 7, + "ts": 0.2702 + }, + { + "layer": 23, + "head": 5, + "ts": 0.2683 + }, + { + "layer": 11, + "head": 5, + "ts": 0.2664 + }, + { + "layer": 3, + "head": 14, + "ts": 0.2645 + }, + { + "layer": 7, + "head": 5, + "ts": 0.2645 + }, + { + "layer": 21, + "head": 7, + "ts": 0.2645 + }, + { + "layer": 17, + "head": 10, + "ts": 0.2627 + }, + { + "layer": 20, + "head": 2, + "ts": 0.2627 + }, + { + "layer": 6, + "head": 1, + "ts": 0.2552 + }, + { + "layer": 9, + "head": 0, + "ts": 0.2552 + }, + { + "layer": 9, + "head": 5, + "ts": 0.2552 + }, + { + "layer": 15, + "head": 2, + "ts": 0.2552 + }, + { + "layer": 9, + "head": 2, + "ts": 0.2533 + }, + { + "layer": 22, + "head": 0, + "ts": 0.2533 + }, + { + "layer": 16, + "head": 8, + "ts": 0.2495 + }, + { + "layer": 9, + "head": 4, + "ts": 0.2458 + }, + { + "layer": 10, + "head": 9, + "ts": 0.2458 + }, + { + "layer": 16, + "head": 5, + "ts": 0.2439 + }, + { + "layer": 17, + "head": 11, + "ts": 0.2439 + }, + { + "layer": 2, + 
"head": 14, + "ts": 0.242 + }, + { + "layer": 4, + "head": 11, + "ts": 0.242 + }, + { + "layer": 22, + "head": 11, + "ts": 0.2402 + }, + { + "layer": 5, + "head": 4, + "ts": 0.2383 + }, + { + "layer": 14, + "head": 11, + "ts": 0.2383 + }, + { + "layer": 26, + "head": 6, + "ts": 0.2383 + }, + { + "layer": 10, + "head": 4, + "ts": 0.2364 + }, + { + "layer": 11, + "head": 6, + "ts": 0.2364 + }, + { + "layer": 14, + "head": 5, + "ts": 0.2364 + }, + { + "layer": 17, + "head": 3, + "ts": 0.2345 + }, + { + "layer": 7, + "head": 6, + "ts": 0.2308 + }, + { + "layer": 13, + "head": 2, + "ts": 0.2308 + }, + { + "layer": 6, + "head": 15, + "ts": 0.2233 + }, + { + "layer": 18, + "head": 13, + "ts": 0.2233 + }, + { + "layer": 10, + "head": 11, + "ts": 0.2214 + }, + { + "layer": 12, + "head": 15, + "ts": 0.2158 + }, + { + "layer": 6, + "head": 4, + "ts": 0.212 + }, + { + "layer": 7, + "head": 13, + "ts": 0.212 + }, + { + "layer": 19, + "head": 1, + "ts": 0.2083 + }, + { + "layer": 14, + "head": 7, + "ts": 0.2045 + }, + { + "layer": 10, + "head": 13, + "ts": 0.2026 + }, + { + "layer": 18, + "head": 6, + "ts": 0.2008 + }, + { + "layer": 5, + "head": 8, + "ts": 0.1989 + }, + { + "layer": 12, + "head": 2, + "ts": 0.1989 + }, + { + "layer": 2, + "head": 6, + "ts": 0.1914 + }, + { + "layer": 13, + "head": 3, + "ts": 0.1914 + }, + { + "layer": 14, + "head": 10, + "ts": 0.1914 + }, + { + "layer": 3, + "head": 15, + "ts": 0.1876 + }, + { + "layer": 22, + "head": 14, + "ts": 0.1876 + }, + { + "layer": 15, + "head": 5, + "ts": 0.1857 + }, + { + "layer": 6, + "head": 5, + "ts": 0.182 + }, + { + "layer": 9, + "head": 10, + "ts": 0.182 + }, + { + "layer": 13, + "head": 5, + "ts": 0.182 + }, + { + "layer": 15, + "head": 11, + "ts": 0.182 + }, + { + "layer": 22, + "head": 10, + "ts": 0.1801 + }, + { + "layer": 3, + "head": 3, + "ts": 0.1782 + }, + { + "layer": 18, + "head": 7, + "ts": 0.1782 + }, + { + "layer": 3, + "head": 2, + "ts": 0.1764 + }, + { + "layer": 4, + "head": 4, + "ts": 0.1745 + 
}, + { + "layer": 8, + "head": 13, + "ts": 0.1726 + }, + { + "layer": 7, + "head": 12, + "ts": 0.1707 + }, + { + "layer": 18, + "head": 15, + "ts": 0.1707 + }, + { + "layer": 15, + "head": 7, + "ts": 0.1689 + }, + { + "layer": 15, + "head": 9, + "ts": 0.1689 + }, + { + "layer": 23, + "head": 13, + "ts": 0.1689 + }, + { + "layer": 8, + "head": 0, + "ts": 0.1651 + }, + { + "layer": 8, + "head": 9, + "ts": 0.1595 + }, + { + "layer": 11, + "head": 4, + "ts": 0.1595 + }, + { + "layer": 0, + "head": 3, + "ts": 0.1576 + }, + { + "layer": 10, + "head": 12, + "ts": 0.1576 + }, + { + "layer": 15, + "head": 4, + "ts": 0.1576 + }, + { + "layer": 2, + "head": 8, + "ts": 0.1557 + }, + { + "layer": 15, + "head": 15, + "ts": 0.1557 + }, + { + "layer": 21, + "head": 2, + "ts": 0.1557 + }, + { + "layer": 5, + "head": 1, + "ts": 0.1538 + }, + { + "layer": 9, + "head": 3, + "ts": 0.1538 + }, + { + "layer": 4, + "head": 8, + "ts": 0.152 + }, + { + "layer": 13, + "head": 10, + "ts": 0.152 + }, + { + "layer": 22, + "head": 2, + "ts": 0.152 + }, + { + "layer": 1, + "head": 15, + "ts": 0.1501 + }, + { + "layer": 7, + "head": 11, + "ts": 0.1501 + }, + { + "layer": 15, + "head": 1, + "ts": 0.1501 + }, + { + "layer": 14, + "head": 2, + "ts": 0.1482 + }, + { + "layer": 18, + "head": 8, + "ts": 0.1482 + }, + { + "layer": 1, + "head": 3, + "ts": 0.1463 + }, + { + "layer": 5, + "head": 11, + "ts": 0.1463 + }, + { + "layer": 5, + "head": 14, + "ts": 0.1463 + }, + { + "layer": 8, + "head": 11, + "ts": 0.1445 + }, + { + "layer": 7, + "head": 14, + "ts": 0.1426 + }, + { + "layer": 8, + "head": 10, + "ts": 0.1407 + }, + { + "layer": 12, + "head": 13, + "ts": 0.1407 + }, + { + "layer": 3, + "head": 8, + "ts": 0.1388 + }, + { + "layer": 27, + "head": 1, + "ts": 0.1388 + }, + { + "layer": 10, + "head": 10, + "ts": 0.137 + }, + { + "layer": 14, + "head": 9, + "ts": 0.137 + }, + { + "layer": 15, + "head": 14, + "ts": 0.137 + }, + { + "layer": 25, + "head": 0, + "ts": 0.137 + }, + { + "layer": 10, + "head": 
15, + "ts": 0.1351 + }, + { + "layer": 27, + "head": 0, + "ts": 0.1351 + }, + { + "layer": 27, + "head": 5, + "ts": 0.1351 + }, + { + "layer": 12, + "head": 5, + "ts": 0.1332 + }, + { + "layer": 15, + "head": 3, + "ts": 0.1332 + }, + { + "layer": 27, + "head": 10, + "ts": 0.1313 + }, + { + "layer": 27, + "head": 8, + "ts": 0.1295 + }, + { + "layer": 9, + "head": 14, + "ts": 0.1276 + }, + { + "layer": 5, + "head": 0, + "ts": 0.122 + }, + { + "layer": 0, + "head": 2, + "ts": 0.1182 + }, + { + "layer": 16, + "head": 3, + "ts": 0.1182 + }, + { + "layer": 4, + "head": 9, + "ts": 0.1163 + }, + { + "layer": 10, + "head": 8, + "ts": 0.1163 + }, + { + "layer": 12, + "head": 9, + "ts": 0.1163 + }, + { + "layer": 14, + "head": 8, + "ts": 0.1163 + }, + { + "layer": 6, + "head": 3, + "ts": 0.1144 + }, + { + "layer": 7, + "head": 15, + "ts": 0.1126 + }, + { + "layer": 12, + "head": 7, + "ts": 0.1126 + }, + { + "layer": 27, + "head": 9, + "ts": 0.1126 + }, + { + "layer": 3, + "head": 6, + "ts": 0.1107 + }, + { + "layer": 8, + "head": 12, + "ts": 0.1107 + }, + { + "layer": 2, + "head": 2, + "ts": 0.1088 + }, + { + "layer": 4, + "head": 2, + "ts": 0.1088 + }, + { + "layer": 5, + "head": 2, + "ts": 0.1088 + }, + { + "layer": 5, + "head": 7, + "ts": 0.1088 + }, + { + "layer": 14, + "head": 1, + "ts": 0.1088 + }, + { + "layer": 22, + "head": 15, + "ts": 0.1088 + }, + { + "layer": 23, + "head": 3, + "ts": 0.1088 + }, + { + "layer": 9, + "head": 11, + "ts": 0.1069 + }, + { + "layer": 17, + "head": 4, + "ts": 0.1069 + }, + { + "layer": 27, + "head": 3, + "ts": 0.1069 + }, + { + "layer": 5, + "head": 9, + "ts": 0.1051 + }, + { + "layer": 14, + "head": 3, + "ts": 0.1032 + }, + { + "layer": 27, + "head": 13, + "ts": 0.1032 + } + ], + "alignment_heads_compact": [ + [ + 20, + 3 + ], + [ + 11, + 2 + ], + [ + 13, + 6 + ], + [ + 11, + 13 + ], + [ + 14, + 14 + ], + [ + 6, + 7 + ], + [ + 14, + 15 + ], + [ + 6, + 6 + ], + [ + 16, + 14 + ], + [ + 6, + 11 + ], + [ + 11, + 12 + ], + [ + 13, + 7 + ], + 
[ + 6, + 13 + ], + [ + 3, + 10 + ], + [ + 11, + 3 + ], + [ + 6, + 10 + ], + [ + 21, + 9 + ], + [ + 19, + 7 + ], + [ + 21, + 8 + ], + [ + 2, + 11 + ], + [ + 13, + 0 + ], + [ + 13, + 11 + ], + [ + 16, + 15 + ], + [ + 4, + 13 + ], + [ + 16, + 10 + ], + [ + 11, + 11 + ], + [ + 18, + 3 + ], + [ + 18, + 5 + ], + [ + 6, + 12 + ], + [ + 12, + 6 + ], + [ + 20, + 14 + ], + [ + 23, + 7 + ], + [ + 19, + 5 + ], + [ + 22, + 7 + ], + [ + 20, + 15 + ], + [ + 22, + 9 + ], + [ + 23, + 14 + ], + [ + 25, + 5 + ], + [ + 22, + 8 + ], + [ + 24, + 14 + ], + [ + 2, + 10 + ], + [ + 25, + 11 + ], + [ + 23, + 15 + ], + [ + 25, + 4 + ], + [ + 26, + 8 + ], + [ + 8, + 1 + ], + [ + 24, + 5 + ], + [ + 24, + 15 + ], + [ + 3, + 11 + ], + [ + 24, + 7 + ], + [ + 2, + 3 + ], + [ + 23, + 4 + ], + [ + 24, + 13 + ], + [ + 25, + 7 + ], + [ + 9, + 12 + ], + [ + 11, + 8 + ], + [ + 22, + 12 + ], + [ + 25, + 10 + ], + [ + 24, + 12 + ], + [ + 24, + 11 + ], + [ + 25, + 12 + ], + [ + 24, + 10 + ], + [ + 25, + 13 + ], + [ + 9, + 13 + ], + [ + 26, + 2 + ], + [ + 25, + 6 + ], + [ + 26, + 3 + ], + [ + 21, + 0 + ], + [ + 25, + 14 + ], + [ + 20, + 11 + ], + [ + 26, + 4 + ], + [ + 26, + 7 + ], + [ + 24, + 6 + ], + [ + 21, + 12 + ], + [ + 22, + 13 + ], + [ + 21, + 1 + ], + [ + 20, + 7 + ], + [ + 20, + 10 + ], + [ + 8, + 4 + ], + [ + 26, + 14 + ], + [ + 4, + 12 + ], + [ + 19, + 3 + ], + [ + 26, + 5 + ], + [ + 13, + 15 + ], + [ + 15, + 10 + ], + [ + 16, + 7 + ], + [ + 13, + 14 + ], + [ + 19, + 9 + ], + [ + 23, + 1 + ], + [ + 23, + 6 + ], + [ + 26, + 15 + ], + [ + 23, + 0 + ], + [ + 25, + 9 + ], + [ + 16, + 13 + ], + [ + 5, + 10 + ], + [ + 16, + 11 + ], + [ + 23, + 10 + ], + [ + 19, + 4 + ], + [ + 13, + 4 + ], + [ + 26, + 0 + ], + [ + 16, + 12 + ], + [ + 12, + 1 + ], + [ + 21, + 13 + ], + [ + 19, + 2 + ], + [ + 19, + 13 + ], + [ + 11, + 10 + ], + [ + 26, + 1 + ], + [ + 24, + 4 + ], + [ + 21, + 15 + ], + [ + 12, + 0 + ], + [ + 19, + 12 + ], + [ + 11, + 9 + ], + [ + 17, + 13 + ], + [ + 19, + 8 + ], + [ + 25, + 15 + ], + [ + 
12, + 8 + ], + [ + 18, + 10 + ], + [ + 7, + 0 + ], + [ + 20, + 4 + ], + [ + 21, + 11 + ], + [ + 18, + 2 + ], + [ + 20, + 12 + ], + [ + 21, + 14 + ], + [ + 10, + 5 + ], + [ + 20, + 6 + ], + [ + 17, + 8 + ], + [ + 7, + 4 + ], + [ + 23, + 11 + ], + [ + 9, + 8 + ], + [ + 19, + 6 + ], + [ + 17, + 0 + ], + [ + 5, + 15 + ], + [ + 8, + 7 + ], + [ + 8, + 14 + ], + [ + 11, + 1 + ], + [ + 10, + 2 + ], + [ + 5, + 5 + ], + [ + 17, + 1 + ], + [ + 4, + 10 + ], + [ + 13, + 1 + ], + [ + 22, + 4 + ], + [ + 22, + 1 + ], + [ + 21, + 10 + ], + [ + 11, + 7 + ], + [ + 20, + 8 + ], + [ + 21, + 4 + ], + [ + 25, + 8 + ], + [ + 3, + 4 + ], + [ + 8, + 2 + ], + [ + 9, + 9 + ], + [ + 18, + 4 + ], + [ + 14, + 12 + ], + [ + 4, + 15 + ], + [ + 17, + 12 + ], + [ + 16, + 9 + ], + [ + 11, + 0 + ], + [ + 12, + 3 + ], + [ + 17, + 9 + ], + [ + 22, + 6 + ], + [ + 6, + 9 + ], + [ + 8, + 5 + ], + [ + 26, + 11 + ], + [ + 8, + 6 + ], + [ + 12, + 10 + ], + [ + 18, + 11 + ], + [ + 14, + 13 + ], + [ + 16, + 6 + ], + [ + 26, + 9 + ], + [ + 1, + 8 + ], + [ + 22, + 5 + ], + [ + 14, + 6 + ], + [ + 12, + 12 + ], + [ + 14, + 0 + ], + [ + 6, + 14 + ], + [ + 12, + 14 + ], + [ + 21, + 6 + ], + [ + 3, + 13 + ], + [ + 10, + 3 + ], + [ + 15, + 6 + ], + [ + 20, + 5 + ], + [ + 6, + 0 + ], + [ + 17, + 2 + ], + [ + 23, + 12 + ], + [ + 20, + 13 + ], + [ + 9, + 7 + ], + [ + 23, + 5 + ], + [ + 11, + 5 + ], + [ + 3, + 14 + ], + [ + 7, + 5 + ], + [ + 21, + 7 + ], + [ + 17, + 10 + ], + [ + 20, + 2 + ], + [ + 6, + 1 + ], + [ + 9, + 0 + ], + [ + 9, + 5 + ], + [ + 15, + 2 + ], + [ + 9, + 2 + ], + [ + 22, + 0 + ], + [ + 16, + 8 + ], + [ + 9, + 4 + ], + [ + 10, + 9 + ], + [ + 16, + 5 + ], + [ + 17, + 11 + ], + [ + 2, + 14 + ], + [ + 4, + 11 + ], + [ + 22, + 11 + ], + [ + 5, + 4 + ], + [ + 14, + 11 + ], + [ + 26, + 6 + ], + [ + 10, + 4 + ], + [ + 11, + 6 + ], + [ + 14, + 5 + ], + [ + 17, + 3 + ], + [ + 7, + 6 + ], + [ + 13, + 2 + ], + [ + 6, + 15 + ], + [ + 18, + 13 + ], + [ + 10, + 11 + ], + [ + 12, + 15 + ], + [ + 6, + 4 + ], + [ + 7, + 
13 + ], + [ + 19, + 1 + ], + [ + 14, + 7 + ], + [ + 10, + 13 + ], + [ + 18, + 6 + ], + [ + 5, + 8 + ], + [ + 12, + 2 + ], + [ + 2, + 6 + ], + [ + 13, + 3 + ], + [ + 14, + 10 + ], + [ + 3, + 15 + ], + [ + 22, + 14 + ], + [ + 15, + 5 + ], + [ + 6, + 5 + ], + [ + 9, + 10 + ], + [ + 13, + 5 + ], + [ + 15, + 11 + ], + [ + 22, + 10 + ], + [ + 3, + 3 + ], + [ + 18, + 7 + ], + [ + 3, + 2 + ], + [ + 4, + 4 + ], + [ + 8, + 13 + ], + [ + 7, + 12 + ], + [ + 18, + 15 + ], + [ + 15, + 7 + ], + [ + 15, + 9 + ], + [ + 23, + 13 + ], + [ + 8, + 0 + ], + [ + 8, + 9 + ], + [ + 11, + 4 + ], + [ + 0, + 3 + ], + [ + 10, + 12 + ], + [ + 15, + 4 + ], + [ + 2, + 8 + ], + [ + 15, + 15 + ], + [ + 21, + 2 + ], + [ + 5, + 1 + ], + [ + 9, + 3 + ], + [ + 4, + 8 + ], + [ + 13, + 10 + ], + [ + 22, + 2 + ], + [ + 1, + 15 + ], + [ + 7, + 11 + ], + [ + 15, + 1 + ], + [ + 14, + 2 + ], + [ + 18, + 8 + ], + [ + 1, + 3 + ], + [ + 5, + 11 + ], + [ + 5, + 14 + ], + [ + 8, + 11 + ], + [ + 7, + 14 + ], + [ + 8, + 10 + ], + [ + 12, + 13 + ], + [ + 3, + 8 + ], + [ + 27, + 1 + ], + [ + 10, + 10 + ], + [ + 14, + 9 + ], + [ + 15, + 14 + ], + [ + 25, + 0 + ], + [ + 10, + 15 + ], + [ + 27, + 0 + ], + [ + 27, + 5 + ], + [ + 12, + 5 + ], + [ + 15, + 3 + ], + [ + 27, + 10 + ], + [ + 27, + 8 + ], + [ + 9, + 14 + ], + [ + 5, + 0 + ], + [ + 0, + 2 + ], + [ + 16, + 3 + ], + [ + 4, + 9 + ], + [ + 10, + 8 + ], + [ + 12, + 9 + ], + [ + 14, + 8 + ], + [ + 6, + 3 + ], + [ + 7, + 15 + ], + [ + 12, + 7 + ], + [ + 27, + 9 + ], + [ + 3, + 6 + ], + [ + 8, + 12 + ], + [ + 2, + 2 + ], + [ + 4, + 2 + ], + [ + 5, + 2 + ], + [ + 5, + 7 + ], + [ + 14, + 1 + ], + [ + 22, + 15 + ], + [ + 23, + 3 + ], + [ + 9, + 11 + ], + [ + 17, + 4 + ], + [ + 27, + 3 + ], + [ + 5, + 9 + ], + [ + 14, + 3 + ], + [ + 27, + 13 + ] + ] +} \ No newline at end of file diff --git a/whisperlivekit/cascade_bridge.py b/whisperlivekit/cascade_bridge.py new file mode 100644 index 0000000..def7ca9 --- /dev/null +++ b/whisperlivekit/cascade_bridge.py @@ -0,0 +1,116 @@ 
+""" +Bridge between WhisperLiveKit STT and IWSLT26 MT pipeline. + +Converts streaming ASRToken output from SimulStreaming into the JSONL +format expected by the AlignAtt MT agent (iwslt26-sst). + +Output format (one JSON per line): + {"text": "word or phrase", "emission_time": 1.234, "is_final": false, "speech_time": 1.0} + +Where: + - text: the emitted word/phrase + - emission_time: wall-clock time when the word was emitted (for compute-aware eval) + - speech_time: timestamp in the audio (for compute-unaware eval) + - is_final: whether this is the last word of a segment/silence boundary +""" + +import json +import time +from typing import List, TextIO + +from whisperlivekit.timed_objects import ASRToken + + +class CascadeBridge: + """Converts ASRToken stream to JSONL for the MT agent.""" + + def __init__(self, output_file: TextIO = None): + self.output_file = output_file + self.start_time = time.time() + self.entries: List[dict] = [] + + def emit_tokens(self, tokens: List[ASRToken], is_final: bool = False): + """Emit a batch of tokens from the STT.""" + wall_clock = time.time() - self.start_time + + for i, token in enumerate(tokens): + entry = { + "text": token.text.strip(), + "emission_time": round(wall_clock, 3), + "speech_time": round(token.start, 3), + "is_final": is_final and (i == len(tokens) - 1), + } + self.entries.append(entry) + if self.output_file: + self.output_file.write(json.dumps(entry) + "\n") + self.output_file.flush() + + def get_entries(self) -> List[dict]: + return self.entries + + def get_text(self) -> str: + """Get the full transcribed text.""" + return " ".join(e["text"] for e in self.entries if e["text"]) + + def save(self, path: str): + """Save all entries to a JSONL file.""" + with open(path, "w") as f: + for entry in self.entries: + f.write(json.dumps(entry) + "\n") + + +def run_stt_to_jsonl( + audio_path: str, + output_path: str, + model_id: str = "Qwen/Qwen3-ASR-0.6B", + alignment_heads_path: str = None, + border_fraction: float = 
def run_stt_to_jsonl(
    audio_path: str,
    output_path: str,
    model_id: str = "Qwen/Qwen3-ASR-0.6B",
    alignment_heads_path: str = None,
    border_fraction: float = 0.20,
    language: str = "en",
    chunk_sec: float = 1.0,
):
    """Run STT on an audio file and save JSONL output for the MT agent.

    This is the main entry point for the cascade: audio file → JSONL.

    Args:
        audio_path: path to a 16 kHz mono 16-bit PCM WAV file.
        output_path: destination JSONL path.
        model_id: HF model id or local directory for Qwen3-ASR.
        alignment_heads_path: optional alignment-heads JSON file.
        border_fraction: trailing fraction of audio tokens treated as unstable.
        language: source language code.
        chunk_sec: size of the simulated streaming chunks, in seconds.

    Returns:
        The CascadeBridge holding all emitted entries.

    Raises:
        ValueError: if the WAV file is not 16 kHz mono 16-bit PCM.
    """
    import wave
    import numpy as np
    from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVASR, Qwen3SimulKVOnlineProcessor

    # Load audio. The downstream STT assumes 16 kHz mono int16; previously a
    # mismatched WAV (e.g. 44.1 kHz or stereo) was silently mis-decoded —
    # fail loudly instead of transcribing garbage.
    with wave.open(audio_path, 'rb') as wf:
        fmt = (wf.getframerate(), wf.getnchannels(), wf.getsampwidth())
        if fmt != (16000, 1, 2):
            raise ValueError(
                f"{audio_path}: expected 16 kHz mono 16-bit PCM WAV, got "
                f"{fmt[0]} Hz, {fmt[1]} channel(s), {8 * fmt[2]}-bit"
            )
        audio = np.frombuffer(
            wf.readframes(wf.getnframes()), dtype=np.int16
        ).astype(np.float32) / 32768.0

    # Initialize STT
    asr = Qwen3SimulKVASR(
        model_dir=model_id,
        lan=language,
        alignment_heads_path=alignment_heads_path,
        border_fraction=border_fraction,
    )
    proc = Qwen3SimulKVOnlineProcessor(asr)
    bridge = CascadeBridge()

    # Stream audio in fixed-size chunks to simulate a live feed.
    chunk_samples = int(chunk_sec * 16000)
    offset = 0
    stream_time = 0.0

    while offset < len(audio):
        chunk = audio[offset:offset + chunk_samples]
        stream_time += len(chunk) / 16000
        proc.insert_audio_chunk(chunk, stream_time)
        words, _ = proc.process_iter(is_last=False)
        if words:
            bridge.emit_tokens(words, is_final=False)
        offset += chunk_samples

    # Final flush: drain anything still held back by the streaming policy.
    final_words, _ = proc.finish()
    if final_words:
        bridge.emit_tokens(final_words, is_final=True)

    # Save
    bridge.save(output_path)
    return bridge
config.backend == "qwen3-simul": from whisperlivekit.qwen3_simul import Qwen3SimulStreamingASR self.tokenizer = None @@ -235,6 +244,9 @@ def online_factory(args, asr, language=None): if backend == "vllm-realtime": from whisperlivekit.vllm_realtime import VLLMRealtimeOnlineProcessor return VLLMRealtimeOnlineProcessor(asr) + if backend == "qwen3-simul-kv": + from whisperlivekit.qwen3_simul_kv import Qwen3SimulKVOnlineProcessor + return Qwen3SimulKVOnlineProcessor(asr) if backend == "qwen3-mlx": from whisperlivekit.qwen3_mlx_asr import Qwen3MLXOnlineProcessor return Qwen3MLXOnlineProcessor(asr) diff --git a/whisperlivekit/qwen3_simul_kv.py b/whisperlivekit/qwen3_simul_kv.py new file mode 100644 index 0000000..276178a --- /dev/null +++ b/whisperlivekit/qwen3_simul_kv.py @@ -0,0 +1,787 @@ +""" +Qwen3-ASR SimulStreaming with KV cache reuse. + +This is an optimized version of qwen3_simul.py that reuses the KV cache +across inference calls, avoiding redundant prefill of prompt + old audio. + +Architecture: + 1. First call: full prefill (prompt + audio tokens), greedy decode with + alignment-head stopping, save KV cache + generated tokens + 2. Subsequent calls: invalidate KV for old audio suffix, prefill only + new audio tokens, continue decoding from saved state + 3. Audio encoder caching: reuse embeddings for stable attention windows + +This gives ~3-5x speedup over the original generate()-based approach. 
+""" + +import json +import logging +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional, Tuple + +import numpy as np +import torch +from transformers import DynamicCache + +from whisperlivekit.timed_objects import ASRToken, ChangeSpeaker, Transcript + +logger = logging.getLogger(__name__) + +SAMPLE_RATE = 16000 + + +@dataclass +class Qwen3SimulKVConfig: + """Configuration for Qwen3 SimulStreaming with KV cache.""" + model_id: str = "Qwen/Qwen3-ASR-1.7B" + alignment_heads_path: Optional[str] = None + language: str = "auto" + border_fraction: float = 0.20 + rewind_fraction: float = 0.12 + audio_min_len: float = 0.5 + audio_max_len: float = 20.0 + max_context_tokens: int = 30 + init_prompt: Optional[str] = None + max_alignment_heads: int = 10 + + +@dataclass +class _AudioEmbedCache: + """Cache for audio encoder outputs.""" + encoded_samples: int = 0 + embeddings: Optional[torch.Tensor] = None + encoded_mel_frames: int = 0 + stable_tokens: int = 0 + + def reset(self): + self.encoded_samples = 0 + self.embeddings = None + self.encoded_mel_frames = 0 + self.stable_tokens = 0 + + +@dataclass +class Qwen3SimulKVState: + """Per-session mutable state with KV cache.""" + # Audio + audio_buffer: np.ndarray = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + cumulative_time_offset: float = 0.0 + global_time_offset: float = 0.0 + speaker: int = -1 + + # KV cache state + kv_cache: Optional[DynamicCache] = None + kv_seq_len: int = 0 # sequence length when KV was saved + prompt_token_count: int = 0 # tokens before audio (system prompt etc) + audio_token_count: int = 0 # audio tokens in the cached KV + generated_token_ids: List[int] = field(default_factory=list) + + # Alignment tracking + last_attend_frame: int = -15 + committed_text: str = "" + committed_word_count: int = 0 + committed_token_ids: List[int] = field(default_factory=list) + + # Tracking + first_timestamp: Optional[float] = None + 
class Qwen3SimulKVASR:
    """
    Shared backend for Qwen3-ASR SimulStreaming with KV cache reuse.

    Loads the model/processor once; per-session decoding state lives in
    Qwen3SimulKVOnlineProcessor. Holds the alignment heads used to read
    token→audio attention during greedy decoding.
    """

    sep = ""

    def __init__(
        self,
        model_size: str = None,
        model_dir: str = None,
        lan: str = "auto",
        alignment_heads_path: Optional[str] = None,
        border_fraction: float = 0.15,
        min_chunk_size: float = 0.1,
        warmup_file: Optional[str] = None,
        model_cache_dir: Optional[str] = None,
        model_path: Optional[str] = None,
        lora_path: Optional[str] = None,
        direct_english_translation: bool = False,
        **kwargs,
    ):
        # Extra params (min_chunk_size, lora_path, direct_english_translation,
        # **kwargs) are accepted for interface parity but not used here.
        self.transcribe_kargs = {}
        self.original_language = None if lan == "auto" else lan
        self.warmup_file = warmup_file

        self.cfg = Qwen3SimulKVConfig(
            language=lan,
            alignment_heads_path=alignment_heads_path,
            border_fraction=border_fraction,
        )

        self._load_model(model_size, model_dir, model_cache_dir, model_path)
        self.alignment_heads = self._load_alignment_heads(alignment_heads_path)

        # Pre-compute heads by layer for efficient hook installation
        self.heads_by_layer = {}
        for layer_idx, head_idx in self.alignment_heads:
            self.heads_by_layer.setdefault(layer_idx, []).append(head_idx)

        if warmup_file:
            from whisperlivekit.warmup import load_file
            audio = load_file(warmup_file)
            if audio is not None:
                self._warmup(audio)

    def _load_model(self, model_size, model_dir, model_cache_dir, model_path):
        """Load model + processor and cache frequently used config values.

        Resolution precedence for the checkpoint id: model_dir > model_path >
        model_size (via QWEN3_MODEL_MAPPING) > the 1.7B default.
        Note: model_cache_dir is accepted but not forwarded to from_pretrained.
        """
        from whisperlivekit.qwen3_asr import QWEN3_MODEL_MAPPING, _patch_transformers_compat
        _patch_transformers_compat()

        from qwen_asr.core.transformers_backend import (
            Qwen3ASRConfig, Qwen3ASRForConditionalGeneration, Qwen3ASRProcessor,
        )
        from transformers import AutoConfig, AutoModel, AutoProcessor

        # Register the custom architecture with the Auto* factories so
        # from_pretrained resolves "qwen3_asr" checkpoints.
        AutoConfig.register("qwen3_asr", Qwen3ASRConfig)
        AutoModel.register(Qwen3ASRConfig, Qwen3ASRForConditionalGeneration)
        AutoProcessor.register(Qwen3ASRConfig, Qwen3ASRProcessor)

        if model_dir:
            model_id = model_dir
        elif model_path:
            model_id = model_path
        elif model_size:
            model_id = QWEN3_MODEL_MAPPING.get(model_size.lower(), model_size)
        else:
            model_id = "Qwen/Qwen3-ASR-1.7B"

        # bf16 on GPU, fp32 on CPU.
        if torch.cuda.is_available():
            dtype, device = torch.bfloat16, "cuda:0"
        else:
            dtype, device = torch.float32, "cpu"

        logger.info("Loading Qwen3-ASR for SimulStreaming+KV: %s", model_id)
        self.model = AutoModel.from_pretrained(model_id, dtype=dtype, device_map=device)
        self.model.eval()
        self.processor = AutoProcessor.from_pretrained(model_id, fix_mistral_regex=True)

        # Cache decoder geometry; the alignment hooks need head counts and the
        # GQA ratio, and _build_full_inputs needs the audio placeholder id.
        thinker = self.model.thinker
        text_config = thinker.config.text_config
        self.num_layers = text_config.num_hidden_layers
        self.num_heads = text_config.num_attention_heads
        self.num_kv_heads = text_config.num_key_value_heads
        self.audio_token_id = thinker.config.audio_token_id
        self.device = next(self.model.parameters()).device
        self.dtype = next(self.model.parameters()).dtype
        # NOTE(review): empty token string here looks like a stripped special
        # token (e.g. an asr-text marker) — confirm against qwen3_simul.py.
        self.asr_text_token_id = self.processor.tokenizer.convert_tokens_to_ids("")

        # EOS tokens (Qwen2/3 chat EOS ids, plus whatever the tokenizer reports)
        self.eos_ids = {151645, 151643}
        if self.processor.tokenizer.eos_token_id is not None:
            self.eos_ids.add(self.processor.tokenizer.eos_token_id)

        logger.info(
            "Qwen3-ASR loaded: %d layers x %d heads, device=%s",
            self.num_layers, self.num_heads, self.device,
        )

    def _load_alignment_heads(self, path):
        """Return up to cfg.max_alignment_heads (layer, head) pairs.

        Falls back to every head in the last quarter of the decoder layers
        when no JSON file is available.
        NOTE(review): expects key "alignment_heads_compact" — confirm it
        matches the shipped scripts/alignment_heads_qwen3_asr_0.6B.json schema.
        """
        max_heads = self.cfg.max_alignment_heads
        if path and Path(path).exists():
            with open(path) as f:
                data = json.load(f)
            all_heads = [tuple(h) for h in data["alignment_heads_compact"]]
            heads = all_heads[:max_heads]
            logger.info("Loaded top %d alignment heads from %s", len(heads), path)
            return heads
        default_heads = []
        start_layer = self.num_layers * 3 // 4
        for layer in range(start_layer, self.num_layers):
            for head in range(self.num_heads):
                default_heads.append((layer, head))
        logger.warning("No alignment heads file. Using %d default heads.", len(default_heads))
        return default_heads[:max_heads]

    def _warmup(self, audio):
        """Run a short generate() pass to trigger lazy CUDA/compile init.

        Best-effort: failures are logged, never raised.
        """
        try:
            audio = audio[:SAMPLE_RATE * 2]  # two seconds are enough
            # NOTE(review): empty "content"/"audio" strings — possibly
            # stripped placeholder tokens; verify against the source repo.
            msgs = [{"role": "system", "content": ""}, {"role": "user", "content": [{"type": "audio", "audio": ""}]}]
            text_prompt = self.processor.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False)
            inputs = self.processor(text=[text_prompt], audio=[audio], return_tensors="pt", padding=True)
            inputs = inputs.to(self.device).to(self.dtype)
            with torch.inference_mode():
                self.model.thinker.generate(**inputs, max_new_tokens=5, do_sample=False)
            logger.info("Warmup complete")
        except Exception as e:
            logger.warning("Warmup failed: %s", e)

    def transcribe(self, audio):
        # Batch transcription is not supported by this streaming backend.
        pass
+ """ + + SAMPLING_RATE = 16000 + MIN_DURATION_REAL_SILENCE = 5 + + def __init__(self, asr: Qwen3SimulKVASR, logfile=sys.stderr): + self.asr = asr + self.logfile = logfile + self.end = 0.0 + self.buffer: List[ASRToken] = [] + self.state = Qwen3SimulKVState() + self._build_prompt_template() + + def _build_prompt_template(self): + from whisperlivekit.qwen3_asr import WHISPER_TO_QWEN3_LANGUAGE + msgs = [ + {"role": "system", "content": ""}, + {"role": "user", "content": [{"type": "audio", "audio": ""}]}, + ] + self._base_prompt = self.asr.processor.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=False, + ) + lan = self.asr.cfg.language + if lan and lan != "auto": + lang_name = WHISPER_TO_QWEN3_LANGUAGE.get(lan, lan) + self._base_prompt += f"language {lang_name}" + + @property + def speaker(self): + return self.state.speaker + + @speaker.setter + def speaker(self, value): + self.state.speaker = value + + @property + def global_time_offset(self): + return self.state.global_time_offset + + @global_time_offset.setter + def global_time_offset(self, value): + self.state.global_time_offset = value + + def insert_audio_chunk(self, audio: np.ndarray, audio_stream_end_time: float): + self.end = audio_stream_end_time + self.state.audio_buffer = np.append(self.state.audio_buffer, audio) + + max_samples = int(self.asr.cfg.audio_max_len * self.SAMPLING_RATE) + if len(self.state.audio_buffer) > max_samples: + trim = len(self.state.audio_buffer) - max_samples + self.state.audio_buffer = self.state.audio_buffer[trim:] + self.state.cumulative_time_offset += trim / self.SAMPLING_RATE + self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim) + self.state.audio_cache.reset() + self.state.reset_kv() # Must invalidate KV when audio is trimmed + + def start_silence(self) -> Tuple[List[ASRToken], float]: + all_tokens = [] + for _ in range(5): + tokens, _ = self.process_iter(is_last=True) + if not tokens: + break + all_tokens.extend(tokens) + return 
    def end_silence(self, silence_duration: float, offset: float):
        """Handle the end of a silence of `silence_duration` seconds.

        Short gaps are stitched into the buffer as zero samples so decoding
        stays continuous; long gaps (>= MIN_DURATION_REAL_SILENCE) reset the
        whole session and restart timestamps at silence_duration + offset.
        """
        self.end += silence_duration
        long_silence = silence_duration >= self.MIN_DURATION_REAL_SILENCE
        if not long_silence:
            gap_len = int(self.SAMPLING_RATE * silence_duration)
            if gap_len > 0:
                self.state.audio_buffer = np.append(
                    self.state.audio_buffer, np.zeros(gap_len, dtype=np.float32),
                )
        else:
            self.state = Qwen3SimulKVState()
            self.state.global_time_offset = silence_duration + offset

    def new_speaker(self, change_speaker: ChangeSpeaker):
        """Flush the current speaker's audio, then start a clean session.

        NOTE(review): the tokens returned by the flushing process_iter call
        are discarded here — confirm the caller collects them elsewhere.
        """
        self.process_iter(is_last=True)
        self.state = Qwen3SimulKVState()
        self.state.speaker = change_speaker.speaker
        self.state.global_time_offset = change_speaker.start

    def get_buffer(self) -> Transcript:
        """Expose uncommitted tokens (self.buffer is never filled here, so
        this is effectively always empty)."""
        return Transcript.from_tokens(tokens=self.buffer, sep='')

    def _encode_audio(self) -> Tuple[torch.Tensor, int]:
        """Encode full audio buffer, with caching for stable windows.

        Returns:
            (audio_embeds, total_audio_tokens): embeddings for the whole
            buffer and the audio-token count derived from the mel frames.
        """
        asr = self.asr
        state = self.state

        from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
            _get_feat_extract_output_lengths,
        )

        # Mel features over the entire rolling buffer.
        feat_out = asr.processor.feature_extractor(
            [state.audio_buffer], sampling_rate=16000,
            padding=True, truncation=False,
            return_attention_mask=True, return_tensors="pt",
        )
        input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
        feature_attention_mask = feat_out["attention_mask"].to(asr.device)
        total_mel_frames = feature_attention_mask.sum().item()
        total_audio_tokens = _get_feat_extract_output_lengths(
            torch.tensor(total_mel_frames),
        ).item()

        cache = state.audio_cache
        audio_cfg = asr.model.thinker.audio_tower.config
        # Complete windows of n_window_infer mel frames are treated as
        # "stable": their encoder output is reused on the next call.
        n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
        n_complete_windows = total_mel_frames // n_window_infer

        if n_complete_windows <= 0 or cache.embeddings is None:
            # Full encode (no cache yet, or not even one complete window).
            audio_embeds = asr.model.thinker.get_audio_features(
                input_features, feature_attention_mask=feature_attention_mask,
            )
            if audio_embeds.dim() == 3:
                audio_embeds = audio_embeds[0]
            stable_mel = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
            # NOTE(review): stable_tokens is never read in this branch (the
            # cache update below recomputes it) — dead assignment.
            stable_tokens = _get_feat_extract_output_lengths(
                torch.tensor(stable_mel),
            ).item() if stable_mel > 0 else 0
        else:
            stable_mel = n_complete_windows * n_window_infer
            stable_tokens = _get_feat_extract_output_lengths(
                torch.tensor(stable_mel),
            ).item()

            if cache.stable_tokens > 0 and cache.stable_tokens <= stable_tokens:
                # Reuse the cached prefix; only the tail is re-encoded.
                # NOTE(review): the prefix slice extends to the NEW stable
                # boundary, so tokens between cache.stable_tokens and
                # stable_tokens were originally encoded with an incomplete
                # final window — confirm this approximation is intended.
                cached_prefix = cache.embeddings[:stable_tokens] if cache.embeddings.dim() == 2 else cache.embeddings[0, :stable_tokens]
                tail_features = input_features[:, :, stable_mel:]
                tail_mel_frames = total_mel_frames - stable_mel
                if tail_mel_frames > 0:
                    tail_mask = torch.ones(
                        (1, tail_features.shape[2]),
                        dtype=feature_attention_mask.dtype,
                        device=feature_attention_mask.device,
                    )
                    tail_embeds = asr.model.thinker.get_audio_features(
                        tail_features, feature_attention_mask=tail_mask,
                    )
                    if tail_embeds.dim() == 3:
                        tail_embeds = tail_embeds[0]
                    audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
                else:
                    audio_embeds = cached_prefix
            else:
                # Cache unusable (stable region moved backwards): re-encode all.
                audio_embeds = asr.model.thinker.get_audio_features(
                    input_features, feature_attention_mask=feature_attention_mask,
                )
                if audio_embeds.dim() == 3:
                    audio_embeds = audio_embeds[0]

        # Update cache with the freshly assembled embeddings.
        cache.embeddings = audio_embeds if audio_embeds.dim() == 2 else audio_embeds[0]
        cache.encoded_samples = len(state.audio_buffer)
        cache.encoded_mel_frames = total_mel_frames
        stable_mel_final = n_complete_windows * n_window_infer if n_complete_windows > 0 else 0
        cache.stable_tokens = _get_feat_extract_output_lengths(
            torch.tensor(stable_mel_final),
        ).item() if stable_mel_final > 0 else 0

        return audio_embeds, total_audio_tokens
    def _build_full_inputs(self, audio_embeds: torch.Tensor) -> dict:
        """Build full input embeddings from prompt + audio embeddings + context.

        Returns None (after logging) when the number of audio placeholder
        tokens in the rendered prompt does not match audio_embeds.
        """
        asr = self.asr
        state = self.state
        thinker = asr.model.thinker

        # NOTE(review): this import is unused in this method — leftover?
        from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
            _get_feat_extract_output_lengths,
        )

        n_audio_tokens = audio_embeds.shape[0]

        # Expand the single audio placeholder in the template into
        # n_audio_tokens placeholder tokens, then tokenize.
        prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
            [self._base_prompt], iter([n_audio_tokens]),
        )[0]
        text_ids = asr.processor.tokenizer(
            [prompt_with_placeholders], return_tensors="pt", padding=True,
        )
        input_ids = text_ids["input_ids"].to(asr.device)
        attention_mask = text_ids.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(asr.device)

        # Append committed context tokens (tail of previously emitted text)
        # so decoding continues where the last call stopped.
        if state.committed_token_ids:
            ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
            ctx_ids = torch.tensor([ctx], dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, ctx_ids], dim=1)
            if attention_mask is not None:
                ctx_mask = torch.ones_like(ctx_ids)
                attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)

        # Build inputs_embeds: embed the text, then scatter the audio
        # embeddings into the placeholder positions.
        inputs_embeds = thinker.get_input_embeddings()(input_ids)
        audio_mask = (input_ids == asr.audio_token_id)
        n_placeholders = audio_mask.sum().item()

        if n_placeholders != n_audio_tokens:
            logger.warning("Audio token mismatch: %d vs %d", n_placeholders, n_audio_tokens)
            return None

        audio_embeds_cast = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(expand_mask, audio_embeds_cast)

        # Find audio token range (used by the alignment hooks to slice keys).
        audio_positions = audio_mask[0].nonzero(as_tuple=True)[0]
        audio_start = audio_positions[0].item()
        audio_end = audio_positions[-1].item() + 1

        return {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "audio_start": audio_start,
            "audio_end": audio_end,
            "n_audio_tokens": n_audio_tokens,
        }
    @torch.inference_mode()
    def process_iter(self, is_last: bool = False) -> Tuple[List[ASRToken], float]:
        """One streaming step: decide whether to infer, then emit words.

        Returns (tokens, stream_end_time); tokens is empty when the buffer is
        too short, too little new audio arrived, or inference failed.
        """
        audio_duration = len(self.state.audio_buffer) / self.SAMPLING_RATE
        if audio_duration < self.asr.cfg.audio_min_len:
            return [], self.end

        # Require at least one second of unseen audio unless finalizing.
        new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
        min_new_seconds = 1.0
        if not is_last and new_samples < int(min_new_seconds * self.SAMPLING_RATE):
            return [], self.end

        # NOTE(review): the watermark is advanced BEFORE _infer runs, so
        # _infer's own (buffer - last_infer_samples) is always <= 0; it also
        # marks the audio consumed even when _infer raises below. Verify.
        self.state.last_infer_samples = len(self.state.audio_buffer)

        try:
            timestamped_words = self._infer(is_last)
        except Exception as e:
            logger.exception("Inference error: %s", e)
            # A failed forward pass leaves the KV cache in an unknown state.
            self.state.reset_kv()
            return [], self.end

        if not timestamped_words:
            return [], self.end

        self.buffer = []
        return timestamped_words, self.end
    def _infer(self, is_last: bool) -> List[ASRToken]:
        """Run inference with KV cache reuse and alignment-head stopping.

        Pipeline: encode audio → splice into the prompt → one prefill →
        greedy decode while reading alignment-head attention; stop early when
        attention reaches the unstable border or jumps backwards (rewind).
        """
        asr = self.asr
        state = self.state
        thinker = asr.model.thinker

        # Step 1: Encode audio (with caching)
        audio_embeds, n_audio_tokens_total = self._encode_audio()

        # Step 2: Build full inputs
        full_inputs = self._build_full_inputs(audio_embeds)
        if full_inputs is None:
            state.reset_kv()
            return []

        input_ids = full_inputs["input_ids"]
        inputs_embeds = full_inputs["inputs_embeds"]
        attention_mask = full_inputs["attention_mask"]
        audio_start = full_inputs["audio_start"]
        audio_end = full_inputs["audio_end"]
        n_audio_tokens = full_inputs["n_audio_tokens"]
        audio_duration = len(state.audio_buffer) / self.SAMPLING_RATE

        # Step 3: Full prefill (we always re-prefill since audio tokens change)
        # Future optimization: partial prefill when only tail audio changes
        out = thinker(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=True,
        )
        kv_cache = out.past_key_values
        # NOTE(review): prompt_len is never used below — dead assignment.
        prompt_len = input_ids.shape[1]

        # Step 4: Greedy decode with alignment head stopping.
        # Thresholds are expressed in audio tokens.
        border_threshold = max(2, int(n_audio_tokens * asr.cfg.border_fraction))
        rewind_threshold = max(2, int(n_audio_tokens * asr.cfg.rewind_fraction))
        last_attend_frame = state.last_attend_frame

        # Install hooks for alignment head attention extraction
        decoder_layers = thinker.model.layers
        num_kv_heads = asr.num_kv_heads
        num_heads = asr.num_heads
        # GQA: several query heads share one KV head.
        gqa_ratio = num_heads // num_kv_heads

        from qwen_asr.core.transformers_backend.modeling_qwen3_asr import apply_rotary_pos_emb

        per_step_frames: List[List[int]] = []
        current_step_frames: List[int] = []
        hooks = []

        def _make_attn_hook(layer_idx):
            # The hook recomputes the rotated query for the single new token
            # and dot-products it against the cached audio keys, recording the
            # argmax audio frame per alignment head.
            head_indices = asr.heads_by_layer[layer_idx]
            def hook_fn(module, args, kwargs, output):
                hidden_states = kwargs.get('hidden_states')
                if hidden_states is None:
                    hidden_states = args[0] if args else None
                # Only fire on single-token decode steps, not on prefill.
                if hidden_states is None or hidden_states.shape[1] != 1:
                    return
                position_embeddings = kwargs.get('position_embeddings')
                if position_embeddings is None and len(args) > 1:
                    position_embeddings = args[1]
                past_kv = kwargs.get('past_key_values')
                if position_embeddings is None or past_kv is None:
                    return

                hidden_shape = (*hidden_states.shape[:-1], -1, module.head_dim)
                q = module.q_norm(module.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
                cos, sin = position_embeddings
                # q is passed for both args; only the rotated query is kept.
                q, _ = apply_rotary_pos_emb(q, q, cos, sin)

                cache_layer = past_kv.layers[module.layer_idx]
                k = cache_layer.keys
                if k is None or audio_end > k.shape[2]:
                    return

                for h_idx in head_indices:
                    if h_idx >= q.shape[1]:
                        continue
                    kv_h_idx = h_idx // gqa_ratio
                    q_h = q[0, h_idx, 0]
                    k_audio = k[0, kv_h_idx, audio_start:audio_end]
                    # Unnormalized attention logits are enough for an argmax.
                    scores = torch.matmul(k_audio, q_h)
                    frame = scores.argmax().item()
                    current_step_frames.append(frame)
            return hook_fn

        for layer_idx in asr.heads_by_layer:
            if layer_idx < len(decoder_layers):
                h = decoder_layers[layer_idx].self_attn.register_forward_hook(
                    _make_attn_hook(layer_idx), with_kwargs=True,
                )
                hooks.append(h)

        try:
            # Greedy decoding with alignment-based stopping
            next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated_ids = []
            border_stop_step = None
            tokens_per_sec = 6  # rough speech-rate budget
            if is_last:
                max_tokens = min(int(audio_duration * tokens_per_sec) + 10, 120)
            else:
                # NOTE(review): process_iter already advanced
                # state.last_infer_samples to the current buffer length, so
                # new_audio_secs is always <= 0 and the max(.., 1.0) clamp
                # makes this budget effectively 11 tokens — confirm intended.
                new_audio_secs = (len(state.audio_buffer) - state.last_infer_samples) / self.SAMPLING_RATE
                max_tokens = min(int(max(new_audio_secs, 1.0) * tokens_per_sec) + 5, 40)

            for step in range(max_tokens):
                tid = next_token.item()
                if tid in asr.eos_ids:
                    break
                generated_ids.append(tid)

                # Collect alignment frames for this step
                if current_step_frames:
                    per_step_frames.append(current_step_frames)
                    current_step_frames = []

                # Check stopping criteria (after 3 tokens)
                if not is_last and len(per_step_frames) >= 3:
                    latest = per_step_frames[-1]
                    if latest:
                        # Median across the alignment heads' argmax frames.
                        frames_sorted = sorted(latest)
                        attended = frames_sorted[len(frames_sorted) // 2]

                        # Rewind: attention jumped backwards — drop last steps.
                        if last_attend_frame - attended > rewind_threshold:
                            border_stop_step = max(0, len(per_step_frames) - 2)
                            break

                        last_attend_frame = attended

                        # Border: attention entered the unstable audio tail.
                        if (n_audio_tokens - attended) <= border_threshold:
                            border_stop_step = len(per_step_frames) - 1
                            break

                # Next token
                out = thinker(
                    input_ids=next_token,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
                kv_cache = out.past_key_values
                next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

            # Flush remaining frames
            if current_step_frames:
                per_step_frames.append(current_step_frames)
        finally:
            for h in hooks:
                h.remove()

        state.last_attend_frame = last_attend_frame

        if not generated_ids:
            return []

        # Strip metadata prefix up to and including the asr-text marker token;
        # the text before it may end with a detected-language name.
        # NOTE(review): all_generated is never used — dead assignment.
        all_generated = torch.tensor(generated_ids, device=asr.device)
        num_gen = len(generated_ids)
        asr_text_id = asr.asr_text_token_id
        metadata_offset = 0
        for i in range(min(num_gen, 10)):
            if generated_ids[i] == asr_text_id:
                if state.detected_language is None and i > 0:
                    from whisperlivekit.qwen3_asr import QWEN3_TO_WHISPER_LANGUAGE
                    prefix_text = asr.processor.tokenizer.decode(
                        generated_ids[:i], skip_special_tokens=True,
                    ).strip()
                    parts = prefix_text.split()
                    if len(parts) >= 2:
                        lang_name = parts[-1]
                        if lang_name.lower() != "none":
                            state.detected_language = QWEN3_TO_WHISPER_LANGUAGE.get(
                                lang_name, lang_name.lower(),
                            )
                metadata_offset = i + 1
                break

        if metadata_offset > 0:
            generated_ids = generated_ids[metadata_offset:]
            num_gen -= metadata_offset
            per_step_frames = per_step_frames[metadata_offset:]

        if num_gen <= 0:
            return []

        # Determine emit count
        # NOTE(review): border_stop_step was counted before the metadata
        # prefix was stripped and is not shifted by metadata_offset — verify.
        if border_stop_step is not None:
            emit_up_to = min(border_stop_step, num_gen)
        else:
            emit_up_to = num_gen

        emitted_ids = generated_ids[:emit_up_to]
        if not emitted_ids:
            return []

        # Build timestamped words
        words = self._build_timestamped_words(
            emitted_ids, per_step_frames, emit_up_to,
            n_audio_tokens, audio_duration,
        )

        state.committed_word_count += len(words)
        # Include metadata in committed tokens for context
        all_emitted = generated_ids[:emit_up_to]
        # NOTE(review): this branch re-assigns the identical slice — no-op.
        if metadata_offset > 0:
            all_emitted = generated_ids[:emit_up_to]  # already stripped
        state.committed_token_ids.extend(all_emitted)

        return words
    def _build_timestamped_words(
        self,
        generated_ids: list,
        step_frames: List[List[int]],
        emit_up_to: int,
        n_audio_tokens: int,
        audio_duration: float,
    ) -> List[ASRToken]:
        """Convert emitted token ids + per-step attention frames to ASRTokens.

        Timestamps are linear interpolations: each word is mapped to an audio
        frame (median head attention, distributed proportionally over words),
        then frame/n_audio_tokens is scaled to audio_duration.
        """
        asr = self.asr
        state = self.state

        # Median attended frame per decode step; None when a step recorded
        # no alignment-head frames.
        per_token_frame = []
        for step in range(emit_up_to):
            if step < len(step_frames) and step_frames[step]:
                frames = sorted(step_frames[step])
                per_token_frame.append(frames[len(frames) // 2])
            else:
                per_token_frame.append(None)

        tokenizer = asr.processor.tokenizer
        full_text = tokenizer.decode(generated_ids[:emit_up_to], skip_special_tokens=True)
        text_words = full_text.split()

        # Words and tokens do not correspond 1:1, so spread the word index
        # proportionally across the collected frames.
        all_frames = [f for f in per_token_frame if f is not None]
        words = []
        for wi, word in enumerate(text_words):
            if all_frames:
                frac = wi / max(len(text_words), 1)
                frame_idx = min(int(frac * len(all_frames)), len(all_frames) - 1)
                frame = all_frames[frame_idx]
            else:
                frame = None
            words.append((word, frame))

        tokens = []
        for i, (text, frame) in enumerate(words):
            text = text.strip()
            if not text:
                continue

            if frame is not None and n_audio_tokens > 0:
                timestamp = (
                    frame / n_audio_tokens * audio_duration
                    + state.cumulative_time_offset
                )
            else:
                # Fallback: spread words evenly across the buffer duration.
                timestamp = (
                    (i / max(len(words), 1)) * audio_duration
                    + state.cumulative_time_offset
                )

            # Prepend a space on every word except the very first of the
            # session, so concatenation with sep="" reads naturally.
            is_very_first_word = (i == 0 and state.committed_word_count == 0)
            display_text = text if is_very_first_word else " " + text

            token = ASRToken(
                start=round(timestamp, 2),
                end=round(timestamp + 0.1, 2),  # nominal 100 ms word length
                text=display_text,
                speaker=state.speaker,
                detected_language=state.detected_language,
            ).with_offset(state.global_time_offset)
            tokens.append(token)

        return tokens
frame = all_frames[frame_idx] + else: + frame = None + words.append((word, frame)) + + tokens = [] + for i, (text, frame) in enumerate(words): + text = text.strip() + if not text: + continue + + if frame is not None and n_audio_tokens > 0: + timestamp = ( + frame / n_audio_tokens * audio_duration + + state.cumulative_time_offset + ) + else: + timestamp = ( + (i / max(len(words), 1)) * audio_duration + + state.cumulative_time_offset + ) + + is_very_first_word = (i == 0 and state.committed_word_count == 0) + display_text = text if is_very_first_word else " " + text + + token = ASRToken( + start=round(timestamp, 2), + end=round(timestamp + 0.1, 2), + text=display_text, + speaker=state.speaker, + detected_language=state.detected_language, + ).with_offset(state.global_time_offset) + tokens.append(token) + + return tokens + + def warmup(self, audio: np.ndarray, init_prompt: str = ""): + try: + self.state.audio_buffer = audio[:SAMPLE_RATE] + self.process_iter(is_last=True) + self.state = Qwen3SimulKVState() + except Exception as e: + logger.warning("Warmup failed: %s", e) + self.state = Qwen3SimulKVState() + + def finish(self) -> Tuple[List[ASRToken], float]: + all_tokens = [] + for _ in range(5): + tokens, _ = self.process_iter(is_last=True) + if not tokens: + break + all_tokens.extend(tokens) + return all_tokens, self.end