mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2026-01-19 18:41:52 +00:00
chore(sync): merge dev into main (#1379)
* Optimize latency (#1259) * add attribute: configs/config.py Optimize latency: tools/rvc_for_realtime.py * new file: assets/Synthesizer_inputs.pth * fix: configs/config.py fix: tools/rvc_for_realtime.py * fix bug: infer/lib/infer_pack/models.py * new file: assets/hubert_inputs.pth new file: assets/rmvpe_inputs.pth modified: configs/config.py new features: infer/lib/rmvpe.py new features: tools/jit_export/__init__.py new features: tools/jit_export/get_hubert.py new features: tools/jit_export/get_rmvpe.py new features: tools/jit_export/get_synthesizer.py optimize: tools/rvc_for_realtime.py * optimize: tools/jit_export/get_synthesizer.py fix bug: tools/jit_export/__init__.py * Fixed a bug caused by using half on the CPU: infer/lib/rmvpe.py Fixed a bug caused by using half on the CPU: tools/jit_export/__init__.py Fixed CIRCULAR IMPORT: tools/jit_export/get_rmvpe.py Fixed CIRCULAR IMPORT: tools/jit_export/get_synthesizer.py Fixed a bug caused by using half on the CPU: tools/rvc_for_realtime.py * Remove useless code: infer/lib/rmvpe.py * Delete gui_v1 copy.py * Delete .vscode/launch.json * Delete jit_export_test.py * Delete tools/rvc_for_realtime copy.py * Delete configs/config.json * Delete .gitignore * Fix exceptions caused by switching inference devices: infer/lib/rmvpe.py Fix exceptions caused by switching inference devices: tools/jit_export/__init__.py Fix exceptions caused by switching inference devices: tools/rvc_for_realtime.py * restore * replace(you can undo this commit) * remove debug_print --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * Fixed some bugs when exporting ONNX model (#1254) * fix import (#1280) * fix import * lint * 🎨 同步 locale (#1242) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Fix jit load and import issue (#1282) * fix jit model loading : infer/lib/rmvpe.py * modified: assets/hubert/.gitignore move file: assets/hubert_inputs.pth -> assets/hubert/hubert_inputs.pth modified: 
assets/rmvpe/.gitignore move file: assets/rmvpe_inputs.pth -> assets/rmvpe/rmvpe_inputs.pth fix import: gui_v1.py * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * Add input wav and delay time monitor for real-time gui (#1293) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * add input wav and delay time monitor --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> * Optimize latency using scripted jit (#1291) * feat(workflow): trigger on dev * feat(workflow): add close-pr on non-dev branch * 🎨 同步 locale (#1289) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: edit PR template * Optimize-latency-using-scripted: configs/config.py Optimize-latency-using-scripted: infer/lib/infer_pack/attentions.py Optimize-latency-using-scripted: infer/lib/infer_pack/commons.py Optimize-latency-using-scripted: infer/lib/infer_pack/models.py Optimize-latency-using-scripted: infer/lib/infer_pack/modules.py Optimize-latency-using-scripted: infer/lib/jit/__init__.py Optimize-latency-using-scripted: infer/lib/jit/get_hubert.py Optimize-latency-using-scripted: infer/lib/jit/get_rmvpe.py Optimize-latency-using-scripted: infer/lib/jit/get_synthesizer.py Optimize-latency-using-scripted: infer/lib/rmvpe.py Optimize-latency-using-scripted: tools/rvc_for_realtime.py * modified: infer/lib/infer_pack/models.py * fix some bug: configs/config.py fix some bug: infer/lib/infer_pack/models.py fix some bug: infer/lib/rmvpe.py * Fixed abnormal reference of logger in multiprocessing: 
infer/modules/train/train.py --------- Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Format code (#1298) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 🎨 同步 locale (#1299) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: optimize actions * feat(workflow): add sync dev * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: optimize actions * feat: add jit options (#1303) Delete useless code: infer/lib/jit/get_synthesizer.py Optimized code: tools/rvc_for_realtime.py * Code refactor + re-design inference ui (#1304) * Code refacor + re-design inference ui * Fix tabname * i18n jp --------- Co-authored-by: Ftps <ftpsflandre@gmail.com> * feat: optimize actions * feat: optimize actions * Update README & en_US locale file (#1309) * critical: some bug fixes (#1322) * JIT acceleration switch does not support hot update * fix padding bug of rmvpe in torch-directml * fix padding bug of rmvpe in torch-directml * Fix STFT under torch_directml (#1330) * chore(format): run black on dev (#1318) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * chore(i18n): sync locale on dev (#1317) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * feat: allow for tta to be passed to uvr (#1361) * chore(format): run black on dev (#1373) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Added script for automatically download all needed models at install (#1366) * Delete modules.py * Add files via upload * Add files via upload * Add files via upload * Add files via upload * chore(i18n): sync locale on dev (#1377) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * 
chore(format): run black on dev (#1376) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> * Update IPEX library (#1362) * Update IPEX library * Update ipex index * chore(format): run black on dev (#1378) Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: Chengjia Jiang <46401978+ChasonJiang@users.noreply.github.com> Co-authored-by: Ftps <ftpsflandre@gmail.com> Co-authored-by: shizuku_nia <102004222+ShizukuNia@users.noreply.github.com> Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> Co-authored-by: yxlllc <33565655+yxlllc@users.noreply.github.com> Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Co-authored-by: Blaise <133521603+blaise-tk@users.noreply.github.com> Co-authored-by: Rice Cake <gak141808@gmail.com> Co-authored-by: AWAS666 <33494149+AWAS666@users.noreply.github.com> Co-authored-by: Dmitry <nda2911@yandex.ru> Co-authored-by: Disty0 <47277141+Disty0@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
fe166e7f3d
commit
e9dd11bddb
79
tools/download_models.py
Normal file
79
tools/download_models.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
RVC_DOWNLOAD_LINK = "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/"
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def dl_model(link, model_name, dir_name):
    """Download one model file and write it under ``dir_name``.

    Parameters
    ----------
    link : str
        Base URL, expected to end with ``/`` (e.g. the RVC HuggingFace root).
    model_name : str
        File name appended to ``link``; may contain sub-directory components,
        which are mirrored below ``dir_name``.
    dir_name : pathlib.Path
        Destination directory; created (with parents) if it does not exist.

    Raises
    ------
    requests.HTTPError
        If the server answers with a non-2xx status (``raise_for_status``).
    """
    dest = dir_name / model_name
    # stream=True: iterate the response body in 8 KiB chunks instead of
    # buffering the whole file (these checkpoints can be hundreds of MB) in RAM.
    with requests.get(f"{link}{model_name}", stream=True) as r:
        r.raise_for_status()
        dest.parent.mkdir(parents=True, exist_ok=True)
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fetch every model the WebUI needs at install time. Output strings and
    # destination paths are part of the observable behavior and are kept as-is.
    print("Downloading hubert_base.pt...")
    dl_model(RVC_DOWNLOAD_LINK, "hubert_base.pt", BASE_DIR / "assets/hubert")
    print("Downloading rmvpe.pt...")
    dl_model(RVC_DOWNLOAD_LINK, "rmvpe.pt", BASE_DIR / "assets/rmvpe")
    print("Downloading vocals.onnx...")
    dl_model(
        RVC_DOWNLOAD_LINK + "uvr5_weights/onnx_dereverb_By_FoxJoy/",
        "vocals.onnx",
        BASE_DIR / "assets/uvr5_weights/onnx_dereverb_By_FoxJoy",
    )

    # Pretrained checkpoint names are the cross product
    # {"", "f0"} x {"D", "G"} x {"32k", "40k", "48k"}; generating them keeps
    # the exact order of the previous hand-written 12-entry list.
    pretrained_names = [
        f"{f0}{kind}{sr}.pth"
        for f0 in ("", "f0")
        for kind in ("D", "G")
        for sr in ("32k", "40k", "48k")
    ]

    print("Downloading pretrained models:")
    for model in pretrained_names:
        print(f"Downloading {model}...")
        dl_model(RVC_DOWNLOAD_LINK + "pretrained/", model, BASE_DIR / "assets/pretrained")

    print("Downloading pretrained models v2:")
    for model in pretrained_names:
        print(f"Downloading {model}...")
        dl_model(RVC_DOWNLOAD_LINK + "pretrained_v2/", model, BASE_DIR / "assets/pretrained_v2")

    print("Downloading uvr5_weights:")
    # NOTE(review): names are percent-encoded on the server and are saved to
    # disk percent-encoded as well — presumably the rest of the app expects
    # these literal file names; confirm before unquoting.
    uvr5_names = [
        "HP2-%E4%BA%BA%E5%A3%B0vocals%2B%E9%9D%9E%E4%BA%BA%E5%A3%B0instrumentals.pth",
        "HP2_all_vocals.pth",
        "HP3_all_vocals.pth",
        "HP5-%E4%B8%BB%E6%97%8B%E5%BE%8B%E4%BA%BA%E5%A3%B0vocals%2B%E5%85%B6%E4%BB%96instrumentals.pth",
        "HP5_only_main_vocal.pth",
        "VR-DeEchoAggressive.pth",
        "VR-DeEchoDeReverb.pth",
        "VR-DeEchoNormal.pth",
    ]
    for model in uvr5_names:
        print(f"Downloading {model}...")
        dl_model(RVC_DOWNLOAD_LINK + "uvr5_weights/", model, BASE_DIR / "assets/uvr5_weights")

    print("All models downloaded!")
|
||||
@@ -1,12 +1,11 @@
|
||||
from io import BytesIO
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from infer.lib import jit
|
||||
from infer.lib.jit.get_synthesizer import get_synthesizer
|
||||
from time import time as ttime
|
||||
|
||||
import fairseq
|
||||
import faiss
|
||||
import numpy as np
|
||||
@@ -31,17 +30,16 @@ from multiprocessing import Manager as M
|
||||
|
||||
from configs.config import Config
|
||||
|
||||
config = Config()
|
||||
# config = Config()
|
||||
|
||||
mm = M()
|
||||
if config.dml == True:
|
||||
|
||||
def forward_dml(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.clone().detach()
|
||||
return res
|
||||
|
||||
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
|
||||
def printt(strr, *args):
|
||||
if len(args) == 0:
|
||||
print(strr)
|
||||
else:
|
||||
print(strr % args)
|
||||
|
||||
|
||||
# config.device=torch.device("cpu")########强制cpu测试
|
||||
@@ -56,18 +54,27 @@ class RVC:
|
||||
n_cpu,
|
||||
inp_q,
|
||||
opt_q,
|
||||
device,
|
||||
config: Config,
|
||||
last_rvc=None,
|
||||
) -> None:
|
||||
"""
|
||||
初始化
|
||||
"""
|
||||
try:
|
||||
global config
|
||||
if config.dml == True:
|
||||
|
||||
def forward_dml(ctx, x, scale):
|
||||
ctx.scale = scale
|
||||
res = x.clone().detach()
|
||||
return res
|
||||
|
||||
fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
|
||||
# global config
|
||||
self.config = config
|
||||
self.inp_q = inp_q
|
||||
self.opt_q = opt_q
|
||||
# device="cpu"########强制cpu测试
|
||||
self.device = device
|
||||
self.device = config.device
|
||||
self.f0_up_key = key
|
||||
self.time_step = 160 / 16000 * 1000
|
||||
self.f0_min = 50
|
||||
@@ -77,11 +84,14 @@ class RVC:
|
||||
self.sr = 16000
|
||||
self.window = 160
|
||||
self.n_cpu = n_cpu
|
||||
self.use_jit = self.config.use_jit
|
||||
self.is_half = config.is_half
|
||||
|
||||
if index_rate != 0:
|
||||
self.index = faiss.read_index(index_path)
|
||||
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
|
||||
logger.info("Index search enabled")
|
||||
self.pth_path = pth_path
|
||||
printt("Index search enabled")
|
||||
self.pth_path: str = pth_path
|
||||
self.index_path = index_path
|
||||
self.index_rate = index_rate
|
||||
|
||||
@@ -91,8 +101,8 @@ class RVC:
|
||||
suffix="",
|
||||
)
|
||||
hubert_model = models[0]
|
||||
hubert_model = hubert_model.to(device)
|
||||
if config.is_half:
|
||||
hubert_model = hubert_model.to(self.device)
|
||||
if self.is_half:
|
||||
hubert_model = hubert_model.half()
|
||||
else:
|
||||
hubert_model = hubert_model.float()
|
||||
@@ -101,46 +111,80 @@ class RVC:
|
||||
else:
|
||||
self.model = last_rvc.model
|
||||
|
||||
if last_rvc is None or last_rvc.pth_path != self.pth_path:
|
||||
cpt = torch.load(self.pth_path, map_location="cpu")
|
||||
self.net_g: nn.Module = None
|
||||
|
||||
def set_default_model():
|
||||
self.net_g, cpt = get_synthesizer(self.pth_path, self.device)
|
||||
self.tgt_sr = cpt["config"][-1]
|
||||
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
|
||||
self.if_f0 = cpt.get("f0", 1)
|
||||
self.version = cpt.get("version", "v1")
|
||||
if self.version == "v1":
|
||||
if self.if_f0 == 1:
|
||||
self.net_g = SynthesizerTrnMs256NSFsid(
|
||||
*cpt["config"], is_half=config.is_half
|
||||
)
|
||||
else:
|
||||
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
||||
elif self.version == "v2":
|
||||
if self.if_f0 == 1:
|
||||
self.net_g = SynthesizerTrnMs768NSFsid(
|
||||
*cpt["config"], is_half=config.is_half
|
||||
)
|
||||
else:
|
||||
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
||||
del self.net_g.enc_q
|
||||
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
|
||||
self.net_g.eval().to(device)
|
||||
# print(2333333333,device,config.device,self.device)#net_g是device,hubert是config.device
|
||||
if config.is_half:
|
||||
if self.is_half:
|
||||
self.net_g = self.net_g.half()
|
||||
else:
|
||||
self.net_g = self.net_g.float()
|
||||
self.is_half = config.is_half
|
||||
|
||||
def set_jit_model():
|
||||
jit_pth_path = self.pth_path.rstrip(".pth")
|
||||
jit_pth_path += ".half.jit" if self.is_half else ".jit"
|
||||
reload = False
|
||||
if str(self.device) == "cuda":
|
||||
self.device = torch.device("cuda:0")
|
||||
if os.path.exists(jit_pth_path):
|
||||
cpt = jit.load(jit_pth_path)
|
||||
model_device = cpt["device"]
|
||||
if model_device != str(self.device):
|
||||
reload = True
|
||||
else:
|
||||
reload = True
|
||||
|
||||
if reload:
|
||||
cpt = jit.synthesizer_jit_export(
|
||||
self.pth_path,
|
||||
"script",
|
||||
None,
|
||||
device=self.device,
|
||||
is_half=self.is_half,
|
||||
)
|
||||
|
||||
self.tgt_sr = cpt["config"][-1]
|
||||
self.if_f0 = cpt.get("f0", 1)
|
||||
self.version = cpt.get("version", "v1")
|
||||
self.net_g = torch.jit.load(
|
||||
BytesIO(cpt["model"]), map_location=self.device
|
||||
)
|
||||
self.net_g.infer = self.net_g.forward
|
||||
self.net_g.eval().to(self.device)
|
||||
|
||||
def set_synthesizer():
|
||||
if self.use_jit and not config.dml:
|
||||
if self.is_half and "cpu" in str(self.device):
|
||||
printt(
|
||||
"Use default Synthesizer model. \
|
||||
Jit is not supported on the CPU for half floating point"
|
||||
)
|
||||
set_default_model()
|
||||
else:
|
||||
set_jit_model()
|
||||
else:
|
||||
set_default_model()
|
||||
|
||||
if last_rvc is None or last_rvc.pth_path != self.pth_path:
|
||||
set_synthesizer()
|
||||
else:
|
||||
self.tgt_sr = last_rvc.tgt_sr
|
||||
self.if_f0 = last_rvc.if_f0
|
||||
self.version = last_rvc.version
|
||||
self.net_g = last_rvc.net_g
|
||||
self.is_half = last_rvc.is_half
|
||||
if last_rvc.use_jit != self.use_jit:
|
||||
set_synthesizer()
|
||||
else:
|
||||
self.net_g = last_rvc.net_g
|
||||
|
||||
if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"):
|
||||
self.model_rmvpe = last_rvc.model_rmvpe
|
||||
except:
|
||||
logger.warning(traceback.format_exc())
|
||||
printt(traceback.format_exc())
|
||||
|
||||
def change_key(self, new_key):
|
||||
self.f0_up_key = new_key
|
||||
@@ -149,7 +193,7 @@ class RVC:
|
||||
if new_index_rate != 0 and self.index_rate == 0:
|
||||
self.index = faiss.read_index(self.index_path)
|
||||
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
|
||||
logger.info("Index search enabled")
|
||||
printt("Index search enabled")
|
||||
self.index_rate = new_index_rate
|
||||
|
||||
def get_f0_post(self, f0):
|
||||
@@ -188,7 +232,7 @@ class RVC:
|
||||
|
||||
pad_size = (p_len - len(f0) + 1) // 2
|
||||
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
||||
# print(pad_size, p_len - len(f0) - pad_size)
|
||||
# printt(pad_size, p_len - len(f0) - pad_size)
|
||||
f0 = np.pad(
|
||||
f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
|
||||
)
|
||||
@@ -243,7 +287,7 @@ class RVC:
|
||||
if "privateuseone" in str(self.device): ###不支持dml,cpu又太慢用不成,拿pm顶替
|
||||
return self.get_f0(x, f0_up_key, 1, "pm")
|
||||
audio = torch.tensor(np.copy(x))[None].float()
|
||||
# print("using crepe,device:%s"%self.device)
|
||||
# printt("using crepe,device:%s"%self.device)
|
||||
f0, pd = torchcrepe.predict(
|
||||
audio,
|
||||
self.sr,
|
||||
@@ -267,7 +311,7 @@ class RVC:
|
||||
if hasattr(self, "model_rmvpe") == False:
|
||||
from infer.lib.rmvpe import RMVPE
|
||||
|
||||
logger.info("Loading rmvpe model")
|
||||
printt("Loading rmvpe model")
|
||||
self.model_rmvpe = RMVPE(
|
||||
# "rmvpe.pt", is_half=self.is_half if self.device.type!="privateuseone" else False, device=self.device if self.device.type!="privateuseone"else "cpu"####dml时强制对rmvpe用cpu跑
|
||||
# "rmvpe.pt", is_half=False, device=self.device####dml配置
|
||||
@@ -275,6 +319,7 @@ class RVC:
|
||||
"assets/rmvpe/rmvpe.pt",
|
||||
is_half=self.is_half,
|
||||
device=self.device, ####正常逻辑
|
||||
use_jit=self.config.use_jit,
|
||||
)
|
||||
# self.model_rmvpe = RMVPE("aug2_58000_half.pt", is_half=self.is_half, device=self.device)
|
||||
f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
|
||||
@@ -292,7 +337,7 @@ class RVC:
|
||||
f0method,
|
||||
) -> np.ndarray:
|
||||
feats = feats.view(1, -1)
|
||||
if config.is_half:
|
||||
if self.config.is_half:
|
||||
feats = feats.half()
|
||||
else:
|
||||
feats = feats.float()
|
||||
@@ -319,17 +364,17 @@ class RVC:
|
||||
weight = np.square(1 / score)
|
||||
weight /= weight.sum(axis=1, keepdims=True)
|
||||
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
||||
if config.is_half:
|
||||
if self.config.is_half:
|
||||
npy = npy.astype("float16")
|
||||
feats[0][-leng_replace_head:] = (
|
||||
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
|
||||
+ (1 - self.index_rate) * feats[0][-leng_replace_head:]
|
||||
)
|
||||
else:
|
||||
logger.warning("Index search FAILED or disabled")
|
||||
printt("Index search FAILED or disabled")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
logger.warning("Index search FAILED")
|
||||
traceback.printt_exc()
|
||||
printt("Index search FAILED")
|
||||
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
||||
t3 = ttime()
|
||||
if self.if_f0 == 1:
|
||||
@@ -356,16 +401,21 @@ class RVC:
|
||||
sid = torch.LongTensor([ii]).to(self.device)
|
||||
with torch.no_grad():
|
||||
if self.if_f0 == 1:
|
||||
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
|
||||
# printt(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
|
||||
infered_audio = self.net_g.infer(
|
||||
feats, p_len, cache_pitch, cache_pitchf, sid, rate
|
||||
feats,
|
||||
p_len,
|
||||
cache_pitch,
|
||||
cache_pitchf,
|
||||
sid,
|
||||
torch.FloatTensor([rate]),
|
||||
)[0][0, 0].data.float()
|
||||
else:
|
||||
infered_audio = self.net_g.infer(feats, p_len, sid, rate)[0][
|
||||
0, 0
|
||||
].data.float()
|
||||
infered_audio = self.net_g.infer(
|
||||
feats, p_len, sid, torch.FloatTensor([rate])
|
||||
)[0][0, 0].data.float()
|
||||
t5 = ttime()
|
||||
logger.info(
|
||||
printt(
|
||||
"Spent time: fea = %.2fs, index = %.2fs, f0 = %.2fs, model = %.2fs",
|
||||
t2 - t1,
|
||||
t3 - t2,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from infer.lib.rmvpe import STFT
|
||||
from torch.nn.functional import conv1d, conv2d
|
||||
from typing import Union, Optional
|
||||
from .utils import linspace, temperature_sigmoid, amp_to_db
|
||||
@@ -139,17 +140,26 @@ class TorchGate(torch.nn.Module):
|
||||
are set to 1, and the rest are set to 0.
|
||||
"""
|
||||
if xn is not None:
|
||||
XN = torch.stft(
|
||||
xn,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
return_complex=True,
|
||||
pad_mode="constant",
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(xn.device),
|
||||
)
|
||||
|
||||
if "privateuseone" in str(xn.device):
|
||||
if not hasattr(self, "stft"):
|
||||
self.stft = STFT(
|
||||
filter_length=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window="hann",
|
||||
).to(xn.device)
|
||||
XN = self.stft.transform(xn)
|
||||
else:
|
||||
XN = torch.stft(
|
||||
xn,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
return_complex=True,
|
||||
pad_mode="constant",
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(xn.device),
|
||||
)
|
||||
XN_db = amp_to_db(XN).to(dtype=X_db.dtype)
|
||||
else:
|
||||
XN_db = X_db
|
||||
@@ -213,16 +223,26 @@ class TorchGate(torch.nn.Module):
|
||||
"""
|
||||
|
||||
# Compute short-time Fourier transform (STFT)
|
||||
X = torch.stft(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
return_complex=True,
|
||||
pad_mode="constant",
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(x.device),
|
||||
)
|
||||
if "privateuseone" in str(x.device):
|
||||
if not hasattr(self, "stft"):
|
||||
self.stft = STFT(
|
||||
filter_length=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window="hann",
|
||||
).to(x.device)
|
||||
X, phase = self.stft.transform(x, return_phase=True)
|
||||
else:
|
||||
X = torch.stft(
|
||||
x,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
return_complex=True,
|
||||
pad_mode="constant",
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(x.device),
|
||||
)
|
||||
|
||||
# Compute signal mask based on stationary or nonstationary assumptions
|
||||
if self.nonstationary:
|
||||
@@ -231,7 +251,7 @@ class TorchGate(torch.nn.Module):
|
||||
sig_mask = self._stationary_mask(amp_to_db(X), xn)
|
||||
|
||||
# Propagate decrease in signal power
|
||||
sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
|
||||
sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0
|
||||
|
||||
# Smooth signal mask with 2D convolution
|
||||
if self.smoothing_filter is not None:
|
||||
@@ -245,13 +265,16 @@ class TorchGate(torch.nn.Module):
|
||||
Y = X * sig_mask.squeeze(1)
|
||||
|
||||
# Inverse STFT to obtain time-domain signal
|
||||
y = torch.istft(
|
||||
Y,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(Y.device),
|
||||
)
|
||||
if "privateuseone" in str(Y.device):
|
||||
y = self.stft.inverse(Y, phase)
|
||||
else:
|
||||
y = torch.istft(
|
||||
Y,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
center=True,
|
||||
window=torch.hann_window(self.win_length).to(Y.device),
|
||||
)
|
||||
|
||||
return y.to(dtype=x.dtype)
|
||||
|
||||
Reference in New Issue
Block a user