optimize the streaming f0 exatrators (#1168)

This commit is contained in:
yxlllc
2023-09-02 15:45:50 +08:00
committed by GitHub
parent ad85b02ed9
commit 0fc160c03e
3 changed files with 91 additions and 71 deletions

View File

@@ -2,7 +2,6 @@ import os
import sys
import traceback
import logging
logger = logging.getLogger(__name__)
from time import time as ttime
@@ -48,7 +47,7 @@ if config.dml == True:
# config.is_half=False########强制cpu测试
class RVC:
def __init__(
self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device
self, key, pth_path, index_path, index_rate, n_cpu, inp_q, opt_q, device, last_rvc=None,
) -> None:
"""
初始化
@@ -72,48 +71,64 @@ class RVC:
self.index = faiss.read_index(index_path)
self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
logger.info("Index search enabled")
self.pth_path = pth_path
self.index_path = index_path
self.index_rate = index_rate
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],
suffix="",
)
hubert_model = models[0]
hubert_model = hubert_model.to(config.device)
if config.is_half:
hubert_model = hubert_model.half()
else:
hubert_model = hubert_model.float()
hubert_model.eval()
self.model = hubert_model
cpt = torch.load(pth_path, map_location="cpu")
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=config.is_half
)
if last_rvc is None:
models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],
suffix="",
)
hubert_model = models[0]
hubert_model = hubert_model.to(config.device)
if config.is_half:
hubert_model = hubert_model.half()
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
self.net_g = self.net_g.half()
hubert_model = hubert_model.float()
hubert_model.eval()
self.model = hubert_model
else:
self.net_g = self.net_g.float()
self.is_half = config.is_half
self.model = last_rvc.model
if last_rvc is None or last_rvc.pth_path != self.pth_path:
cpt = torch.load(self.pth_path, map_location="cpu")
self.tgt_sr = cpt["config"][-1]
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
self.if_f0 = cpt.get("f0", 1)
self.version = cpt.get("version", "v1")
if self.version == "v1":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs256NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
elif self.version == "v2":
if self.if_f0 == 1:
self.net_g = SynthesizerTrnMs768NSFsid(
*cpt["config"], is_half=config.is_half
)
else:
self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
del self.net_g.enc_q
logger.debug(self.net_g.load_state_dict(cpt["weight"], strict=False))
self.net_g.eval().to(device)
# print(2333333333,device,config.device,self.device)#net_g是devicehubert是config.device
if config.is_half:
self.net_g = self.net_g.half()
else:
self.net_g = self.net_g.float()
self.is_half = config.is_half
else:
self.tgt_sr = last_rvc.tgt_sr
self.if_f0 = last_rvc.if_f0
self.version = last_rvc.version
self.net_g = last_rvc.net_g
self.is_half = last_rvc.is_half
if last_rvc is not None and hasattr(last_rvc, "model_rmvpe"):
self.model_rmvpe = last_rvc.model_rmvpe
except:
logger.warn(traceback.format_exc())
@@ -149,7 +164,7 @@ class RVC:
if method == "rmvpe":
return self.get_f0_rmvpe(x, f0_up_key)
if method == "pm":
p_len = x.shape[0] // 160
p_len = x.shape[0] // 160 + 1
f0 = (
parselmouth.Sound(x, 16000)
.to_pitch_ac(
@@ -181,9 +196,10 @@ class RVC:
f0 = signal.medfilt(f0, 3)
f0 *= pow(2, f0_up_key / 12)
return self.get_f0_post(f0)
f0bak = np.zeros(x.shape[0] // 160, dtype=np.float64)
f0bak = np.zeros(x.shape[0] // 160 + 1, dtype=np.float64)
length = len(x)
part_length = int(length / n_cpu / 160) * 160
part_length = 160 * ((length // 160 - 1) // n_cpu + 1)
n_cpu = (length // 160 - 1) // (part_length // 160) + 1
ts = ttime()
res_f0 = mm.dict()
for idx in range(n_cpu):
@@ -205,7 +221,7 @@ class RVC:
elif idx != n_cpu - 1:
f0 = f0[2:-3]
else:
f0 = f0[2:-1]
f0 = f0[2:]
f0bak[
part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
] = f0
@@ -259,8 +275,8 @@ class RVC:
self,
feats: torch.Tensor,
indata: np.ndarray,
rate1,
rate2,
block_frame_16k,
rate,
cache_pitch,
cache_pitchf,
f0method,
@@ -286,7 +302,7 @@ class RVC:
t2 = ttime()
try:
if hasattr(self, "index") and self.index_rate != 0:
leng_replace_head = int(rate1 * feats[0].shape[0])
leng_replace_head = int(rate * feats[0].shape[0])
npy = feats[0][-leng_replace_head:].cpu().numpy().astype("float32")
score, ix = self.index.search(npy, k=8)
weight = np.square(1 / score)
@@ -307,9 +323,11 @@ class RVC:
t3 = ttime()
if self.if_f0 == 1:
pitch, pitchf = self.get_f0(indata, self.f0_up_key, self.n_cpu, f0method)
cache_pitch[:] = np.append(cache_pitch[pitch[:-1].shape[0] :], pitch[:-1])
start_frame = block_frame_16k // 160
end_frame = len(cache_pitch) - (pitch.shape[0] - 4) + start_frame
cache_pitch[:] = np.append(cache_pitch[start_frame : end_frame], pitch[3:-1])
cache_pitchf[:] = np.append(
cache_pitchf[pitchf[:-1].shape[0] :], pitchf[:-1]
cache_pitchf[start_frame : end_frame], pitchf[3:-1]
)
p_len = min(feats.shape[1], 13000, cache_pitch.shape[0])
else:
@@ -330,14 +348,14 @@ class RVC:
# print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
infered_audio = (
self.net_g.infer(
feats, p_len, cache_pitch, cache_pitchf, sid, rate2
feats, p_len, cache_pitch, cache_pitchf, sid, rate
)[0][0, 0]
.data.cpu()
.float()
)
else:
infered_audio = (
self.net_g.infer(feats, p_len, sid, rate2)[0][0, 0]
self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
.data.cpu()
.float()
)