mirror of
https://github.com/QuentinFuxa/WhisperLiveKit.git
synced 2026-03-07 22:33:36 +00:00
17 lines
847 B
Python
17 lines
847 B
Python
from torch import Tensor

from .whisper.decoding import PyTorchInference
|
|
|
|
# Extension of PyTorchInference for beam search
|
|
class BeamPyTorchInference(PyTorchInference):
    """PyTorchInference variant whose KV cache can be reordered in place.

    Entries of ``self.kv_cache`` are looked up by the ``cache_id`` attribute
    of each decoder block's ``attn.key`` and ``attn.value`` sub-modules.
    NOTE(review): ``self.model`` and ``self.kv_cache`` are presumably set up
    by the parent class -- confirm against ``PyTorchInference``.
    """

    def _kv_modules(self):
        """Return the cache ids used to index ``self.kv_cache``.

        The list contains the ``attn.key`` cache id of every decoder block,
        followed by the ``attn.value`` cache id of every decoder block.
        """
        blocks = self.model.decoder.blocks
        key_ids = [block.attn.key.cache_id for block in blocks]
        value_ids = [block.attn.value.cache_id for block in blocks]
        return key_ids + value_ids

    def rearrange_kv_cache(self, source_indices):
        """Reorder every cached entry along its first axis.

        Args:
            source_indices: For each output position, the index of the cached
                row to copy from. Presumably a list of beam indices -- verify
                against the caller.

        Nothing is touched when ``source_indices`` is already the identity
        ordering ``[0, 1, ..., n - 1]``.
        """
        if source_indices != list(range(len(source_indices))):
            for cache_id in self._kv_modules():
                reordered = self.kv_cache[cache_id][source_indices]
                # Store a detached result, as the original code did.
                self.kv_cache[cache_id] = reordered.detach()

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Run the decoder on ``tokens`` and ``audio_features`` with the KV cache.

        Returns whatever ``self.model.decoder`` returns for these inputs.
        """
        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)