"""
|
|
Voxtral Realtime MLX model — encoder, decoder, adapter, and top-level model.
|
|
|
|
Architecture:
|
|
audio → StreamingEncoder → EncoderToDecoderAdapter → TextDecoder → logits
|
|
with DelayEmbedding providing time-conditioning to the decoder.
|
|
|
|
The model supports both batch inference (full audio) and incremental streaming
|
|
(one chunk at a time with cached encoder/decoder state).
|
|
"""
|
|
|
|
import math
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KV Cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class SlidingKVCache:
|
|
"""Bounded key-value cache with rotating buffer for sliding-window attention.
|
|
|
|
Uses in-place writes for single-token autoregressive steps and
|
|
concatenation for multi-token prefills. Pre-allocates in blocks of
|
|
``alloc_step`` entries to reduce repeated allocation.
|
|
"""
|
|
|
|
alloc_step = 256
|
|
|
|
def __init__(self, capacity: int):
|
|
self.capacity = capacity
|
|
self.keys = None
|
|
self.values = None
|
|
self._offset = 0
|
|
self._write_idx = 0
|
|
|
|
@property
|
|
def offset(self) -> int:
|
|
return self._offset
|
|
|
|
# -- helpers --
|
|
|
|
def _reorder(self, buf):
|
|
"""Return *buf* in temporal order (unwrap the circular buffer)."""
|
|
if self._write_idx == buf.shape[2]:
|
|
return buf
|
|
if self._write_idx < self._offset:
|
|
return mx.concatenate(
|
|
[buf[..., self._write_idx:, :], buf[..., : self._write_idx, :]],
|
|
axis=2,
|
|
)
|
|
return buf[..., : self._write_idx, :]
|
|
|
|
def _drop_oldest(self, buf, n_drop, tail=None):
|
|
parts = [buf[..., n_drop:, :]] if n_drop > 0 else [buf]
|
|
if tail is not None:
|
|
parts.append(tail)
|
|
return mx.concatenate(parts, axis=2)
|
|
|
|
# -- update strategies --
|
|
|
|
def _append_concat(self, k, v):
|
|
"""Multi-token update via concatenation (used during prefill)."""
|
|
if self.keys is None:
|
|
self.keys, self.values = k, v
|
|
else:
|
|
self.keys = self._reorder(self.keys)
|
|
self.values = self._reorder(self.values)
|
|
self._write_idx = self.keys.shape[2]
|
|
overflow = self._write_idx - self.capacity + 1
|
|
self.keys = self._drop_oldest(self.keys, overflow, k)
|
|
self.values = self._drop_oldest(self.values, overflow, v)
|
|
self._offset += k.shape[2]
|
|
self._write_idx = self.keys.shape[2]
|
|
return self.keys, self.values
|
|
|
|
def _write_inplace(self, k, v):
|
|
"""Single-token update via in-place write (autoregressive step)."""
|
|
B, n_heads, S, dim_k = k.shape
|
|
dim_v = v.shape[3]
|
|
prev = self._offset
|
|
|
|
if self.keys is None or (
|
|
prev >= self.keys.shape[2] and self.keys.shape[2] < self.capacity
|
|
):
|
|
n_new = min(self.alloc_step, self.capacity - prev)
|
|
fresh_k = mx.zeros((B, n_heads, n_new, dim_k), k.dtype)
|
|
fresh_v = mx.zeros((B, n_heads, n_new, dim_v), v.dtype)
|
|
if self.keys is not None:
|
|
self.keys = mx.concatenate([self.keys, fresh_k], axis=2)
|
|
self.values = mx.concatenate([self.values, fresh_v], axis=2)
|
|
else:
|
|
self.keys, self.values = fresh_k, fresh_v
|
|
self._write_idx = prev
|
|
|
|
overflow = self.keys.shape[2] - self.capacity
|
|
if overflow > 0:
|
|
self.keys = self._drop_oldest(self.keys, overflow)
|
|
self.values = self._drop_oldest(self.values, overflow)
|
|
self._write_idx = self.capacity
|
|
|
|
if self._write_idx == self.capacity:
|
|
self._write_idx = 0
|
|
|
|
self.keys[..., self._write_idx : self._write_idx + S, :] = k
|
|
self.values[..., self._write_idx : self._write_idx + S, :] = v
|
|
self._offset += S
|
|
self._write_idx += S
|
|
|
|
if self._offset < self.capacity:
|
|
return (
|
|
self.keys[..., : self._offset, :],
|
|
self.values[..., : self._offset, :],
|
|
)
|
|
return self.keys, self.values
|
|
|
|
# -- public API --
|
|
|
|
def update_and_fetch(self, k, v):
|
|
if k.shape[2] == 1:
|
|
return self._write_inplace(k, v)
|
|
return self._append_concat(k, v)
|
|
|
|
|
|
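

# Illustrative usage sketch (not called by the model): exercises the two update
# paths described in the docstring above. A multi-token prefill takes the concat
# path; subsequent single-token steps take the in-place path and, once ``capacity``
# entries have been written, wrap around and overwrite the oldest entries. The
# shapes and the capacity of 8 are arbitrary example values, not model settings.
def _example_sliding_kv_cache():
    cache = SlidingKVCache(capacity=8)
    k = mx.random.normal((1, 4, 5, 64))              # 5-token prefill
    v = mx.random.normal((1, 4, 5, 64))
    keys, values = cache.update_and_fetch(k, v)      # concat path
    for _ in range(6):                               # autoregressive steps
        k = mx.random.normal((1, 4, 1, 64))
        v = mx.random.normal((1, 4, 1, 64))
        keys, values = cache.update_and_fetch(k, v)  # in-place path, wraps past 8
    return keys.shape, values.shape                  # at most (1, 4, 8, 64)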


# ---------------------------------------------------------------------------
# Encoder components
# ---------------------------------------------------------------------------


class CausalConv(nn.Module):
    """1-D causal convolution (left-padded so no future leakage)."""

    def __init__(self, channels_in: int, channels_out: int, kernel: int, stride: int = 1):
        super().__init__()
        self.stride = stride
        self.kernel = kernel
        self.left_pad = kernel - stride
        self.weight = mx.zeros((channels_out, kernel, channels_in))
        self.bias = mx.zeros((channels_out,))

    def __call__(self, x: mx.array) -> mx.array:
        if self.left_pad > 0:
            x = mx.pad(x, [(0, 0), (self.left_pad, 0), (0, 0)])
        return mx.conv1d(x, self.weight, stride=self.stride) + self.bias
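

# Illustrative sketch (arbitrary example shapes): because padding is applied only on
# the left, output frame t depends only on input frames <= t, so the outputs for a
# prefix of the input match the outputs computed on the full input. The random
# weights below are stand-ins; real weights are loaded from the checkpoint.
def _example_causal_conv():
    conv = CausalConv(channels_in=4, channels_out=4, kernel=3, stride=1)
    conv.weight = mx.random.normal(conv.weight.shape)
    x = mx.random.normal((1, 10, 4))                          # [batch, frames, channels]
    y_full = conv(x)
    y_prefix = conv(x[:, :6, :])
    return mx.allclose(y_full[:, :6, :], y_prefix, atol=1e-5)  # expected: True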


class _EncoderSelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, head_dim: int, rope_theta: float):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.k_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
        self.rope_theta = rope_theta

    def __call__(self, x, mask, cache=None):
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

        pos = cache.offset if cache is not None else 0
        q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
        k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)

        if cache is not None:
            k, v = cache.update_and_fetch(k, v)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))


class _EncoderFFN(nn.Module):
    """SwiGLU feed-forward for encoder layers."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=True)

    def __call__(self, x):
        return self.down(nn.silu(self.gate(x)) * self.up(x))


class _EncoderBlock(nn.Module):
    def __init__(self, dim, n_heads, head_dim, hidden, rope_theta):
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.self_attn = _EncoderSelfAttention(dim, n_heads, head_dim, rope_theta)
        self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.ffn = _EncoderFFN(dim, hidden)

    def __call__(self, x, mask, cache=None):
        x = x + self.self_attn(self.pre_attn_norm(x), mask, cache=cache)
        x = x + self.ffn(self.pre_ffn_norm(x))
        return x


class StreamingEncoder(nn.Module):
    """Causal Whisper-style encoder with two causal convolutions followed by
    a stack of transformer blocks. Supports both full-sequence and
    incremental (streaming) forward passes."""

    def __init__(
        self,
        mel_channels: int = 128,
        dim: int = 1280,
        n_layers: int = 32,
        n_heads: int = 32,
        head_dim: int = 64,
        hidden_dim: int = 5120,
        rope_theta: float = 1e6,
        sliding_window: int = 750,
    ):
        super().__init__()
        self.conv1 = CausalConv(mel_channels, dim, kernel=3, stride=1)
        self.conv2 = CausalConv(dim, dim, kernel=3, stride=2)
        self.blocks = [
            _EncoderBlock(dim, n_heads, head_dim, hidden_dim, rope_theta)
            for _ in range(n_layers)
        ]
        self.final_norm = nn.RMSNorm(dim, eps=1e-5)
        self.sliding_window = sliding_window

    # -- full-sequence --

    def _apply_convs(self, mel: mx.array) -> mx.array:
        x = mel.T[None, :, :]  # [1, T, mel_channels]
        x = nn.gelu(self.conv1(x))
        x = nn.gelu(self.conv2(x))
        return x

    def forward(self, mel: mx.array) -> mx.array:
        x = self._apply_convs(mel.astype(self.conv1.weight.dtype))
        for blk in self.blocks:
            x = blk(x, mask="causal")
        return self.final_norm(x)

    # -- incremental (streaming) --

    def forward_conv_incremental(self, x_in, tail1, tail2):
        """Process new mel frames through the two causal convs using cached tails.

        Args:
            x_in: [1, N, mel_channels]
            tail1: [1, pad1, mel_channels] or None (first call)
            tail2: [1, pad2, dim] or None (first call)

        Returns:
            (out, new_tail1, new_tail2)
        """
        # Conv1 (kernel=3, stride=1 → left_pad=2)
        if tail1 is not None:
            c1_in = mx.concatenate([tail1, x_in], axis=1)
        else:
            c1_in = mx.pad(x_in, [(0, 0), (self.conv1.left_pad, 0), (0, 0)])
        new_tail1 = x_in[:, -self.conv1.left_pad :, :]
        c1_out = nn.gelu(
            mx.conv1d(c1_in, self.conv1.weight, stride=self.conv1.stride) + self.conv1.bias
        )

        # Conv2 (kernel=3, stride=2 → left_pad=1)
        if tail2 is not None:
            c2_in = mx.concatenate([tail2, c1_out], axis=1)
        else:
            c2_in = mx.pad(c1_out, [(0, 0), (self.conv2.left_pad, 0), (0, 0)])
        new_tail2 = c1_out[:, -self.conv2.left_pad :, :]
        c2_out = nn.gelu(
            mx.conv1d(c2_in, self.conv2.weight, stride=self.conv2.stride) + self.conv2.bias
        )

        return c2_out, new_tail1, new_tail2

    def forward_transformer_incremental(self, x, cache_list):
        """Run transformer blocks with per-layer KV caches."""
        for i, blk in enumerate(self.blocks):
            x = blk(x, mask="causal", cache=cache_list[i])
        return self.final_norm(x)
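

# Illustrative streaming sketch (assumed chunking and cache capacity): mel frames are
# fed chunk by chunk; the convolutional tails and one SlidingKVCache per block carry
# state across calls, mirroring how VoxtralMLXModel.encode_incremental drives the
# encoder further down in this file.
def _example_streaming_encoder(encoder: StreamingEncoder, mel_chunks):
    tail1 = tail2 = None
    caches = [SlidingKVCache(100_000) for _ in encoder.blocks]
    outputs = []
    for chunk in mel_chunks:                   # each chunk: [mel_channels, n_frames]
        x = chunk.T[None, :, :].astype(encoder.conv1.weight.dtype)
        x, tail1, tail2 = encoder.forward_conv_incremental(x, tail1, tail2)
        outputs.append(encoder.forward_transformer_incremental(x, caches))
    return mx.concatenate(outputs, axis=1)     # roughly [1, total_frames / 2, dim]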


# ---------------------------------------------------------------------------
# Decoder components
# ---------------------------------------------------------------------------


class _DecoderAttention(nn.Module):
    """Grouped-query attention for the text decoder."""

    def __init__(self, dim, n_heads, n_kv_heads, head_dim, rope_theta):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope_theta = rope_theta

    def __call__(self, x, mask=None, cache=None):
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)

        pos = cache.offset if cache is not None else 0
        q = mx.fast.rope(q, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)
        k = mx.fast.rope(k, self.head_dim, traditional=True, base=self.rope_theta, scale=1.0, offset=pos)

        if cache is not None:
            k, v = cache.update_and_fetch(k, v)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        return self.out_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
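

# Illustrative sketch (small assumed dimensions, not the model's): in grouped-query
# attention several query heads share each key/value head, and the fast
# scaled_dot_product_attention kernel broadcasts the smaller K/V head count across
# the query heads, so no explicit head repetition is needed in the layer above.
def _example_gqa_shapes():
    attn = _DecoderAttention(dim=256, n_heads=8, n_kv_heads=2, head_dim=32, rope_theta=1e6)
    x = mx.random.normal((1, 4, 256))
    return attn(x, mask="causal").shape  # (1, 4, 256)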


class _DecoderFFN(nn.Module):
    """SwiGLU feed-forward for decoder layers."""

    def __init__(self, dim, hidden):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x):
        return self.down(nn.silu(self.gate(x)) * self.up(x))


class AdaptiveScaling(nn.Module):
    """Small MLP that produces a multiplicative scale from the delay embedding,
    used to condition the FFN on the streaming delay."""

    def __init__(self, dim, bottleneck):
        super().__init__()
        self.proj_in = nn.Linear(dim, bottleneck, bias=False)
        self.proj_out = nn.Linear(bottleneck, dim, bias=False)

    def __call__(self, cond):
        return self.proj_out(nn.gelu(self.proj_in(cond)))


class _DecoderBlock(nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads, head_dim, hidden, rope_theta, cond_dim):
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.self_attn = _DecoderAttention(dim, n_heads, n_kv_heads, head_dim, rope_theta)
        self.adaptive_scale = AdaptiveScaling(dim, cond_dim)
        self.pre_ffn_norm = nn.RMSNorm(dim, eps=1e-5)
        self.ffn = _DecoderFFN(dim, hidden)

    def __call__(self, x, delay_cond, mask=None, cache=None):
        x = x + self.self_attn(self.pre_attn_norm(x), mask, cache)
        scaled = self.pre_ffn_norm(x) * (1.0 + self.adaptive_scale(delay_cond))
        x = x + self.ffn(scaled)
        return x


class TextDecoder(nn.Module):
    """Mistral-style causal language model with adaptive time-conditioning."""

    def __init__(
        self,
        dim: int = 3072,
        n_layers: int = 26,
        n_heads: int = 32,
        n_kv_heads: int = 8,
        head_dim: int = 128,
        hidden_dim: int = 9216,
        vocab_size: int = 131072,
        rope_theta: float = 1e6,
        cond_dim: int = 32,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.blocks = [
            _DecoderBlock(dim, n_heads, n_kv_heads, head_dim, hidden_dim, rope_theta, cond_dim)
            for _ in range(n_layers)
        ]
        self.final_norm = nn.RMSNorm(dim, eps=1e-5)

    def embed(self, token_ids: mx.array) -> mx.array:
        return self.token_embedding(token_ids)

    def __call__(self, x, delay_cond, mask=None, cache=None):
        delay_cond = delay_cond.astype(x.dtype)
        for i, blk in enumerate(self.blocks):
            blk_cache = cache[i] if cache is not None else None
            x = blk(x, delay_cond, mask, blk_cache)
        x = self.final_norm(x)
        return self.token_embedding.as_linear(x)
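

# Illustrative incremental-decoding sketch (placeholder token id and conditioning):
# with one SlidingKVCache per decoder block, each call processes a single new token
# and attends over the cached keys/values, so no mask is needed after the prefill.
def _example_decoder_step(decoder: TextDecoder, delay_cond: mx.array, token_id: int, caches):
    x = decoder.embed(mx.array([token_id]))[None]  # [1, 1, dim]
    logits = decoder(x, delay_cond, mask=None, cache=caches)
    return mx.argmax(logits[0, -1])                # greedy next-token id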


# ---------------------------------------------------------------------------
# Adapter & embeddings
# ---------------------------------------------------------------------------


class EncoderToDecoderAdapter(nn.Module):
    """Two-layer projection from encoder space to decoder space."""

    def __init__(self, enc_dim: int, dec_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(enc_dim, dec_dim, bias=False)
        self.linear2 = nn.Linear(dec_dim, dec_dim, bias=False)

    def __call__(self, x):
        return self.linear2(nn.gelu(self.linear1(x)))


class DelayEmbedding(nn.Module):
    """Sinusoidal embedding that encodes the streaming delay as a conditioning
    vector for the decoder's adaptive scaling."""

    def __init__(self, dim: int = 3072, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        half = dim // 2
        freqs = mx.exp(-math.log(theta) * mx.arange(half, dtype=mx.float32) / half)
        self._freqs = freqs

    def __call__(self, delay: mx.array) -> mx.array:
        t = delay.reshape(-1, 1).astype(mx.float32)
        angles = t * self._freqs
        return mx.concatenate([mx.cos(angles), mx.sin(angles)], axis=-1)
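

# Illustrative sketch (the delay value 6.0 is a placeholder): a scalar delay maps to a
# [1, dim] vector of concatenated cosines and sines, cond = [cos(t * f), sin(t * f)],
# which every decoder block's AdaptiveScaling turns into a per-channel FFN scale.
def _example_delay_embedding():
    emb = DelayEmbedding(dim=3072)
    cond = emb(mx.array([6.0]))
    return cond.shape  # (1, 3072)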


# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------


class VoxtralMLXModel(nn.Module):
    """Top-level Voxtral Realtime model wiring encoder, adapter, and decoder."""

    def __init__(self, config: dict):
        super().__init__()

        enc_cfg = config["multimodal"]["whisper_model_args"]["encoder_args"]
        audio_cfg = enc_cfg["audio_encoding_args"]
        ds_factor = config["multimodal"]["whisper_model_args"]["downsample_args"]["downsample_factor"]

        self.encoder = StreamingEncoder(
            mel_channels=audio_cfg["num_mel_bins"],
            dim=enc_cfg["dim"],
            n_layers=enc_cfg["n_layers"],
            n_heads=enc_cfg["n_heads"],
            head_dim=enc_cfg["head_dim"],
            hidden_dim=enc_cfg["hidden_dim"],
            rope_theta=enc_cfg["rope_theta"],
            sliding_window=enc_cfg["sliding_window"],
        )

        adapter_input_dim = enc_cfg["dim"] * ds_factor
        decoder_dim = config["dim"]
        cond_bottleneck = config.get("ada_rms_norm_t_cond_dim", 32)

        self.adapter = EncoderToDecoderAdapter(adapter_input_dim, decoder_dim)

        self.decoder = TextDecoder(
            dim=decoder_dim,
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            n_kv_heads=config["n_kv_heads"],
            head_dim=config["head_dim"],
            hidden_dim=config["hidden_dim"],
            vocab_size=config["vocab_size"],
            rope_theta=config["rope_theta"],
            cond_dim=cond_bottleneck,
        )

        self.delay_embedding = DelayEmbedding(dim=decoder_dim)
        self.ds_factor = ds_factor

    # -- batch encode --

    def encode(self, mel: mx.array) -> mx.array:
        T = mel.shape[1]
        if T % 2 != 0:
            mel = mel[:, 1:]

        h = self.encoder.forward(mel)  # [1, T/2, enc_dim]
        h = h[0]

        n = h.shape[0]
        trim = n % self.ds_factor
        if trim:
            h = h[trim:]
            n = h.shape[0]

        h = h.reshape(n // self.ds_factor, -1)
        return self.adapter(h)

    # -- incremental encode --

    def encode_incremental(self, new_mel, conv_tail1, conv_tail2, enc_cache, ds_remainder):
        """Incrementally encode new mel frames.

        Returns:
            (audio_embeds | None, conv_tail1, conv_tail2, enc_cache, ds_remainder)
        """
        x = new_mel.T[None, :, :].astype(self.encoder.conv1.weight.dtype)

        x, conv_tail1, conv_tail2 = self.encoder.forward_conv_incremental(x, conv_tail1, conv_tail2)

        if enc_cache is None:
            enc_cache = [SlidingKVCache(100_000) for _ in range(len(self.encoder.blocks))]

        x = self.encoder.forward_transformer_incremental(x, enc_cache)
        x = x[0]  # [N, enc_dim]

        if ds_remainder is not None:
            x = mx.concatenate([ds_remainder, x])

        n_full = (x.shape[0] // self.ds_factor) * self.ds_factor
        if n_full == 0:
            return None, conv_tail1, conv_tail2, enc_cache, x

        leftover = x[n_full:] if x.shape[0] > n_full else None
        x = x[:n_full].reshape(n_full // self.ds_factor, -1)
        return self.adapter(x), conv_tail1, conv_tail2, enc_cache, leftover

    # -- decode --

    def decode(self, embeddings, delay_cond, mask=None, cache=None):
        return self.decoder(embeddings, delay_cond, mask, cache)
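

# Illustrative end-to-end sketch (assumed inputs; the prompt layout and the delay
# value are placeholders, not the actual Voxtral prompt format used by
# voxtral_mlx_asr.py): given an already constructed model and a mel spectrogram of
# shape [mel_channels, T], run the batch path and take a greedy argmax over the
# logits at the last position.
def _example_batch_decode(model: VoxtralMLXModel, mel: mx.array, prompt_ids: mx.array, delay: float = 6.0):
    audio_embeds = model.encode(mel)                        # [n_audio, dec_dim]
    text_embeds = model.decoder.embed(prompt_ids)           # [n_prompt, dec_dim]
    x = mx.concatenate([audio_embeds, text_embeds])[None]   # [1, n_audio + n_prompt, dec_dim]
    delay_cond = model.delay_embedding(mx.array([delay]))   # [1, dec_dim]
    logits = model.decode(x, delay_cond, mask="causal")
    return mx.argmax(logits[0, -1])                         # greedy next-token id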