chore(sync): merge dev into main (#1379)

* Optimize latency (#1259)

* add attribute:   configs/config.py
	Optimize latency:   tools/rvc_for_realtime.py

* new file:   assets/Synthesizer_inputs.pth

* fix:   configs/config.py
	fix:   tools/rvc_for_realtime.py

* fix bug:   infer/lib/infer_pack/models.py

* new file:   assets/hubert_inputs.pth
	new file:   assets/rmvpe_inputs.pth
	modified:   configs/config.py
	new features:   infer/lib/rmvpe.py
	new features:   tools/jit_export/__init__.py
	new features:   tools/jit_export/get_hubert.py
	new features:   tools/jit_export/get_rmvpe.py
	new features:   tools/jit_export/get_synthesizer.py
	optimize:   tools/rvc_for_realtime.py

* optimize:   tools/jit_export/get_synthesizer.py
	fix bug:   tools/jit_export/__init__.py

* Fixed a bug caused by using half on the CPU:   infer/lib/rmvpe.py
	Fixed a bug caused by using half on the CPU:   tools/jit_export/__init__.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_rmvpe.py
	Fixed CIRCULAR IMPORT:   tools/jit_export/get_synthesizer.py
	Fixed a bug caused by using half on the CPU:   tools/rvc_for_realtime.py

* Remove useless code:   infer/lib/rmvpe.py

* Delete gui_v1 copy.py

* Delete .vscode/launch.json

* Delete jit_export_test.py

* Delete tools/rvc_for_realtime copy.py

* Delete configs/config.json

* Delete .gitignore

* Fix exceptions caused by switching inference devices:   infer/lib/rmvpe.py
	Fix exceptions caused by switching inference devices:   tools/jit_export/__init__.py
	Fix exceptions caused by switching inference devices:   tools/rvc_for_realtime.py

* restore

* replace(you can undo this commit)

* remove debug_print

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* Fixed some bugs when exporting ONNX model (#1254)

* fix import (#1280)

* fix import

* lint

* 🎨 同步 locale (#1242)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Fix jit load and import issue (#1282)

* fix jit model loading :   infer/lib/rmvpe.py

* modified:   assets/hubert/.gitignore
	move file:    assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth
	modified:   assets/rmvpe/.gitignore
	move file:    assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth
	fix import:   gui_v1.py

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* Add input wav and delay time monitor for real-time gui (#1293)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 同步 locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* add input wav and delay time monitor

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>

* Optimize latency using scripted jit (#1291)

* feat(workflow): trigger on dev

* feat(workflow): add close-pr on non-dev branch

* 🎨 同步 locale (#1289)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: edit PR template

* Optimize-latency-using-scripted:   configs/config.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/attentions.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/commons.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/models.py
	Optimize-latency-using-scripted:   infer/lib/infer_pack/modules.py
	Optimize-latency-using-scripted:   infer/lib/jit/__init__.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_hubert.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_rmvpe.py
	Optimize-latency-using-scripted:   infer/lib/jit/get_synthesizer.py
	Optimize-latency-using-scripted:   infer/lib/rmvpe.py
	Optimize-latency-using-scripted:   tools/rvc_for_realtime.py

* modified:   infer/lib/infer_pack/models.py

* fix some bug:   configs/config.py
	fix some bug:   infer/lib/infer_pack/models.py
	fix some bug:   infer/lib/rmvpe.py

* Fixed abnormal reference of logger in multiprocessing:   infer/modules/train/train.py

---------

Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Format code (#1298)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* 🎨 同步 locale (#1299)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: optimize actions

* feat(workflow): add sync dev

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: optimize actions

* feat: add jit options (#1303)

Delete useless code:   infer/lib/jit/get_synthesizer.py
	Optimized code:   tools/rvc_for_realtime.py

* Code refactor + re-design inference ui (#1304)

* Code refacor + re-design inference ui

* Fix tabname

* i18n jp

---------

Co-authored-by: Ftps <ftpsflandre@gmail.com>

* feat: optimize actions

* feat: optimize actions

* Update README & en_US locale file (#1309)

* critical: some bug fixes (#1322)

* JIT acceleration switch does not support hot update

* fix padding bug of rmvpe in torch-directml

* fix padding bug of rmvpe in torch-directml

* Fix STFT under torch_directml (#1330)

* chore(format): run black on dev (#1318)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(i18n): sync locale on dev (#1317)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* feat: allow for tta to be passed to uvr (#1361)

* chore(format): run black on dev (#1373)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Added script for automatically download all needed models at install (#1366)

* Delete modules.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* chore(i18n): sync locale on dev (#1377)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* chore(format): run black on dev (#1376)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Update IPEX library (#1362)

* Update IPEX library

* Update ipex index

* chore(format): run black on dev (#1378)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com>
Co-authored-by: Ftps <ftpsflandre@gmail.com>
Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com>
Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com>
Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com>
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com>
Co-authored-by: Rice Cake <gak141808@gmail.com>
Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com>
Co-authored-by: Dmitry <nda2911@yandex.ru>
Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
This commit is contained in:
github-actions[bot]
2023-10-06 17:14:33 +08:00
committed by GitHub
parent fe166e7f3d
commit e9dd11bddb
42 changed files with 2014 additions and 1120 deletions

79
tools/download_models.py Normal file
View File

@@ -0,0 +1,79 @@
import os
from pathlib import Path
import requests
RVC_DOWNLOAD_LINK = "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/"
BASE_DIR = Path(__file__).resolve().parent.parent
def dl_model(link, model_name, dir_name):
with requests.get(f"{link}{model_name}") as r:
r.raise_for_status()
os.makedirs(os.path.dirname(dir_name / model_name), exist_ok=True)
with open(dir_name / model_name, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
if __name__ == "__main__":
print("Downloading hubert_base.pt...")
dl_model(RVC_DOWNLOAD_LINK, "hubert_base.pt", BASE_DIR / "assets/hubert")
print("Downloading rmvpe.pt...")
dl_model(RVC_DOWNLOAD_LINK, "rmvpe.pt", BASE_DIR / "assets/rmvpe")
print("Downloading vocals.onnx...")
dl_model(
RVC_DOWNLOAD_LINK + "uvr5_weights/onnx_dereverb_By_FoxJoy/",
"vocals.onnx",
BASE_DIR / "assets/uvr5_weights/onnx_dereverb_By_FoxJoy",
)
rvc_models_dir = BASE_DIR / "assets/pretrained"
print("Downloading pretrained models:")
model_names = [
"D32k.pth",
"D40k.pth",
"D48k.pth",
"G32k.pth",
"G40k.pth",
"G48k.pth",
"f0D32k.pth",
"f0D40k.pth",
"f0D48k.pth",
"f0G32k.pth",
"f0G40k.pth",
"f0G48k.pth",
]
for model in model_names:
print(f"Downloading {model}...")
dl_model(RVC_DOWNLOAD_LINK + "pretrained/", model, rvc_models_dir)
rvc_models_dir = BASE_DIR / "assets/pretrained_v2"
print("Downloading pretrained models v2:")
for model in model_names:
print(f"Downloading {model}...")
dl_model(RVC_DOWNLOAD_LINK + "pretrained_v2/", model, rvc_models_dir)
print("Downloading uvr5_weights:")
rvc_models_dir = BASE_DIR / "assets/uvr5_weights"
model_names = [
"HP2-%E4%BA%BA%E5%A3%B0vocals%2B%E9%9D%9E%E4%BA%BA%E5%A3%B0instrumentals.pth",
"HP2_all_vocals.pth",
"HP3_all_vocals.pth",
"HP5-%E4%B8%BB%E6%97%8B%E5%BE%8B%E4%BA%BA%E5%A3%B0vocals%2B%E5%85%B6%E4%BB%96instrumentals.pth",
"HP5_only_main_vocal.pth",
"VR-DeEchoAggressive.pth",
"VR-DeEchoDeReverb.pth",
"VR-DeEchoNormal.pth",
]
for model in model_names:
print(f"Downloading {model}...")
dl_model(RVC_DOWNLOAD_LINK + "uvr5_weights/", model, rvc_models_dir)
print("All models downloaded!")

View File

@@ -1,12 +1,11 @@
from io import BytesIO
import os
import pickle
import sys
import traceback
import logging
logger = logging.getLogger(__name__)
from infer.lib import jit
from infer.lib.jit.get_synthesizer import get_synthesizer
from time import time as ttime
import fairseq
import faiss
import numpy as np
@@ -31,17 +30,16 @@ from multiprocessing import Manager as M
from configs.config import Config
config = Config()
# config = Config()
mm = M()
if config.dml == True:
def forward_dml(ctx, x, scale):
ctx.scale = scale
res = x.clone().detach()
return res
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
def printt(strr, *args):
if len(args) == 0:
print(strr)
else:
print(strr % args)
# config.device=torch.device("cpu")########强制cpu测试
@@ -56,18 +54,27 @@ class RVC:
n_cpu,
inp_q,
opt_q,
device,
config: Config,
last_rvc=None,
) -> None:
"""
初始化
"""
try:
global config
if config.dml == True:
def forward_dml(ctx, x, scale):
ctx.scale = scale
res = x.clone().detach()
return res
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
# global config
self.config = config
self.inp_q = inp_q
self.opt_q = opt_q
# device="cpu"########强制cpu测试
self.device = device
self.device = config.device
self.f0_up_key = key
self.time_step = 160 / 16000 * 1000
self.f0_min = 50
@@ -77,11 +84,14 @@ class RVC:
self.sr = 16000
self.window = 160
self.n_cpu = n_cpu
self.use_jit = self.config.use_jit
self.is_half = config.is_half
if index_rate != 0:
self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
logger.info("Index search enabled")
self.pth_path = pth_path
printt("Index search enabled")
self.pth_path: str = pth_path
self.index_path = index_path
self.index_rate = index_rate
@@ -91,8 +101,8 @@ class RVC:
suffix="",
)
hubert_model = models[0]
hubert_model = hubert_model.to(device)
if config.is_half:
hubert_model = hubert_model.to(self.device)
if self.is_half:
hubert_model = hubert_model.half()
else:
hubert_model = hubert_model.float()
@@ -101,46 +111,80 @@ class RVC:
else:
self.model = last_rvc.model
if last_rvc is None or last_rvc.pth_path != self.pth_path:
cpt = torch.load(self.pth_path, map_location="cpu")
self.net_g: nn.Module = None
def set_default_model():
self.net_g, cpt = get_synthesizer(self.pth_path, self.device)
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
if self.is_half:
self.net_g = self.net_g.half()
else:
self.net_g = self.net_g.float()
self.is_half = config.is_half
def set_jit_model():
jit_pth_path = self.pth_path.rstrip(".pth")
jit_pth_path += ".half.jit" if self.is_half else ".jit"
reload = False
if str(self.device) == "cuda":
self.device = torch.device("cuda:0")
if os.path.exists(jit_pth_path):
cpt = jit.load(jit_pth_path)
model_device = cpt["device"]
if model_device != str(self.device):
reload = True
else:
reload = True
if reload:
cpt = jit.synthesizer_jit_export(
self.pth_path,
"script",
None,
device=self.device,
is_half=self.is_half,
)
self.tgt_sr = cpt["config"][-1]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
self.net_g = torch.jit.load(
BytesIO(cpt["model"]), map_location=self.device
)
self.net_g.infer = self.net_g.forward
self.net_g.eval().to(self.device)
def set_synthesizer():
if self.use_jit and not config.dml:
if self.is_half and "cpu" in str(self.device):
printt(
"Use default Synthesizer model. \
Jit is not supported on the CPU for half floating point"
)
set_default_model()
else:
set_jit_model()
else:
set_default_model()
if last_rvc is None or last_rvc.pth_path != self.pth_path:
set_synthesizer()
else:
self.tgt_sr = last_rvc.tgt_sr
self.if_f0 = last_rvc.if_f0
self.version = last_rvc.version
self.net_g = last_rvc.net_g
self.is_half = last_rvc.is_half
if last_rvc.use_jit != self.use_jit:
set_synthesizer()
else:
self.net_g = last_rvc.net_g
if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"):
self.model_rmvpe = last_rvc.model_rmvpe
except:
logger.warning(traceback.format_exc())
printt(traceback.format_exc())
def change_key(self, new_key):
self.f0_up_key = new_key
@@ -149,7 +193,7 @@ class RVC:
if new_index_rate != 0 and self.index_rate == 0:
self.index = faiss.read_index(self.index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
logger.info("Index search enabled")
printt("Index search enabled")
self.index_rate = new_index_rate
def get_f0_post(self, f0):
@@ -188,7 +232,7 @@ class RVC:
pad_size = (p_len - len(f0) + 1) // 2
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
# print(pad_size, p_len - len(f0) - pad_size)
# printt(pad_size, p_len - len(f0) - pad_size)
f0 = np.pad(
f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
)
@@ -243,7 +287,7 @@ class RVC:
if "privateuseone" in str(self.device): ###不支持dmlcpu又太慢用不成拿pm顶替
return self.get_f0(x, f0_up_key, 1, "pm")
audio = torch.tensor(np.copy(x))[None].float()
# print("using crepe,device:%s"%self.device)
# printt("using crepe,device:%s"%self.device)
f0, pd = torchcrepe.predict(
audio,
self.sr,
@@ -267,7 +311,7 @@ class RVC:
if hasattr(self, "model_rmvpe") == False:
from infer.lib.rmvpe import RMVPE
logger.info("Loading rmvpe model")
printt("Loading rmvpe model")
self.model_rmvpe = RMVPE(
# "rmvpe.pt", is_half=self.is_half if self.device.type!="privateuseone" else False, device=self.device if self.device.type!="privateuseone"else "cpu"####dml时强制对rmvpe用cpu跑
# "rmvpe.pt", is_half=False, device=self.device####dml配置
@@ -275,6 +319,7 @@ class RVC:
"assets/rmvpe/rmvpe.pt",
is_half=self.is_half,
device=self.device, ####正常逻辑
use_jit=self.config.use_jit,
)
# self.model_rmvpe = RMVPE("aug2_58000_half.pt", is_half=self.is_half, device=self.device)
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
@@ -292,7 +337,7 @@ class RVC:
f0method,
) -> np.ndarray:
feats = feats.view(1, -1)
if config.is_half:
if self.config.is_half:
feats = feats.half()
else:
feats = feats.float()
@@ -319,17 +364,17 @@ class RVC:
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
if config.is_half:
if self.config.is_half:
npy = npy.astype("float16")
feats[0][-leng_replace_head:] = (
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
+ (1 - self.index_rate) * feats[0][-leng_replace_head:]
)
else:
logger.warning("Index search FAILED or disabled")
printt("Index search FAILED or disabled")
except:
traceback.print_exc()
logger.warning("Index search FAILED")
traceback.printt_exc()
printt("Index search FAILED")
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
t3 = ttime()
if self.if_f0 == 1:
@@ -356,16 +401,21 @@ class RVC:
sid = torch.LongTensor([ii]).to(self.device)
with torch.no_grad():
if self.if_f0 == 1:
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
# printt(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
infered_audio = self.net_g.infer(
feats, p_len, cache_pitch, cache_pitchf, sid, rate
feats,
p_len,
cache_pitch,
cache_pitchf,
sid,
torch.FloatTensor([rate]),
)[0][0, 0].data.float()
else:
infered_audio = self.net_g.infer(feats, p_len, sid, rate)[0][
0, 0
].data.float()
infered_audio = self.net_g.infer(
feats, p_len, sid, torch.FloatTensor([rate])
)[0][0, 0].data.float()
t5 = ttime()
logger.info(
printt(
"Spent time: fea = %.2fs, index = %.2fs, f0 = %.2fs, model = %.2fs",
t2 - t1,
t3 - t2,

View File

@@ -1,4 +1,5 @@
import torch
from infer.lib.rmvpe import STFT
from torch.nn.functional import conv1d, conv2d
from typing import Union, Optional
from .utils import linspace, temperature_sigmoid, amp_to_db
@@ -139,17 +140,26 @@ class TorchGate(torch.nn.Module):
are set to 1, and the rest are set to 0.
"""
if xn is not None:
XN = torch.stft(
xn,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(xn.device),
)
if "privateuseone" in str(xn.device):
if not hasattr(self, "stft"):
self.stft = STFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window="hann",
).to(xn.device)
XN = self.stft.transform(xn)
else:
XN = torch.stft(
xn,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(xn.device),
)
XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
else:
XN_db = X_db
@@ -213,16 +223,26 @@ class TorchGate(torch.nn.Module):
"""
# Compute short-time Fourier transform (STFT)
X = torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(x.device),
)
if "privateuseone" in str(x.device):
if not hasattr(self, "stft"):
self.stft = STFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window="hann",
).to(x.device)
X, phase = self.stft.transform(x, return_phase=True)
else:
X = torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
return_complex=True,
pad_mode="constant",
center=True,
window=torch.hann_window(self.win_length).to(x.device),
)
# Compute signal mask based on stationary or nonstationary assumptions
if self.nonstationary:
@@ -231,7 +251,7 @@ class TorchGate(torch.nn.Module):
sig_mask = self._stationary_mask(amp_to_db(X), xn)
# Propagate decrease in signal power
sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0
# Smooth signal mask with 2D convolution
if self.smoothing_filter is not None:
@@ -245,13 +265,16 @@ class TorchGate(torch.nn.Module):
Y = X * sig_mask.squeeze(1)
# Inverse STFT to obtain time-domain signal
y = torch.istft(
Y,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=True,
window=torch.hann_window(self.win_length).to(Y.device),
)
if "privateuseone" in str(Y.device):
y = self.stft.inverse(Y, phase)
else:
y = torch.istft(
Y,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=True,
window=torch.hann_window(self.win_length).to(Y.device),
)
return y.to(dtype=x.dtype)