mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
from torch import Tensor
|
|
|
|
from whisperlivekit.whisper.decoding import PyTorchInference
|
|
|
|
|
|
class BeamPyTorchInference(PyTorchInference):
    """PyTorchInference variant for beam search with cross-attention support.

    Adds reordering of the cached self-attention keys/values and lets the
    caller ask the decoder for cross-attention weights alongside the logits.
    """

    def _kv_cache_ids(self):
        """Return the self-attention cache ids: all key ids, then all value ids."""
        decoder_blocks = self.model.decoder.blocks
        ids = [blk.attn.key_cache_id for blk in decoder_blocks]
        ids.extend(blk.attn.value_cache_id for blk in decoder_blocks)
        return ids

    def rearrange_kv_cache(self, source_indices):
        """Reorder the cached keys/values by ``source_indices``.

        Nothing is touched when ``source_indices`` already equals the
        identity order ``[0, 1, ..., n - 1]``. Cache ids that are not
        present in ``self.kv_cache`` are skipped.
        """
        identity_order = list(range(len(source_indices)))
        needs_reorder = source_indices != identity_order
        if not needs_reorder:
            return

        cache = self.kv_cache
        for cid in self._kv_cache_ids():
            if cid not in cache:
                continue
            # Index by the new order, then detach the result before storing it.
            cache[cid] = cache[cid][source_indices].detach()

    def logits(
        self,
        tokens: Tensor,
        audio_features: Tensor,
        return_cross_attn: bool = False,
    ):
        """Run the decoder on ``tokens`` and ``audio_features`` with the kv cache.

        ``return_cross_attn`` is forwarded unchanged to the decoder, and the
        decoder's result is returned as-is.
        """
        decoder = self.model.decoder
        return decoder(
            tokens,
            audio_features,
            kv_cache=self.kv_cache,
            return_cross_attn=return_cross_attn,
        )