optimize real-time vc

yxlllc
2023-12-26 00:23:36 +08:00
parent 78f03e7dc0
commit 3dec36568c
6 changed files with 211 additions and 135 deletions

View File

@@ -722,7 +722,8 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
     def remove_weight_norm(self):
         self.dec.remove_weight_norm()
         self.flow.remove_weight_norm()
-        self.enc_q.remove_weight_norm()
+        if hasattr(self, "enc_q"):
+            self.enc_q.remove_weight_norm()

     def __prepare_scriptable__(self):
         for hook in self.dec._forward_pre_hooks.values():
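
The hasattr guard matters because real-time loaders commonly delete enc_q, the posterior encoder, which is only needed during training, before stripping weight norm. A minimal sketch of that interaction; the construction arguments are assumed rather than taken from this diff:

# Hedged sketch: cpt is assumed to be a loaded checkpoint dict whose "config"
# entry holds the constructor arguments (as in get_synthesizer further down).
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
del net_g.enc_q                  # inference-only deployments drop the training-only encoder
net_g.remove_weight_norm()       # previously AttributeError; now skipped via the hasattr check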
@@ -783,14 +784,14 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
         pitch: torch.Tensor,
         nsff0: torch.Tensor,
         sid: torch.Tensor,
-        rate: Optional[torch.Tensor] = None,
+        skip_head: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
         m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
         z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if rate is not None:
-            assert isinstance(rate, torch.Tensor)
-            head = int(z_p.shape[2] * (1 - rate.item()))
+        if skip_head is not None:
+            assert isinstance(skip_head, torch.Tensor)
+            head = int(skip_head.item())
             z_p = z_p[:, :, head:]
             x_mask = x_mask[:, :, head:]
             nsff0 = nsff0[:, head:]
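
The switch from rate to skip_head also changes what the caller communicates: instead of the fraction of frames to keep, which has to be converted back into a frame count by float multiplication and truncation, the caller now passes the number of leading context frames to drop directly. A small worked comparison, with illustrative numbers that are not taken from this diff:

import torch

# Illustrative: drop 33 leading context frames out of T = 100 feature frames.
T, skip = 100, 33

# old convention: encode the request as the fraction of frames to keep,
# then reconstruct the count by float multiplication and truncation
rate = torch.tensor((T - skip) / T)        # ~0.67
head_old = int(T * (1 - rate.item()))      # 32, not 33: float rounding drops a frame

# new convention: pass the frame count itself (still a tensor, to keep the
# TorchScript-friendly signature)
skip_head = torch.tensor(skip)
head_new = int(skip_head.item())           # exactly 33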
@@ -887,7 +888,8 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
     def remove_weight_norm(self):
         self.dec.remove_weight_norm()
         self.flow.remove_weight_norm()
-        self.enc_q.remove_weight_norm()
+        if hasattr(self, "enc_q"):
+            self.enc_q.remove_weight_norm()

     def __prepare_scriptable__(self):
         for hook in self.dec._forward_pre_hooks.values():
@@ -941,13 +943,14 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
         pitch: torch.Tensor,
         nsff0: torch.Tensor,
         sid: torch.Tensor,
-        rate: Optional[torch.Tensor] = None,
+        skip_head: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
         m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
         z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if rate is not None:
-            head = int(z_p.shape[2] * (1.0 - rate.item()))
+        if skip_head is not None:
+            assert isinstance(skip_head, torch.Tensor)
+            head = int(skip_head.item())
             z_p = z_p[:, :, head:]
             x_mask = x_mask[:, :, head:]
             nsff0 = nsff0[:, head:]
@@ -1041,7 +1044,8 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
     def remove_weight_norm(self):
         self.dec.remove_weight_norm()
         self.flow.remove_weight_norm()
-        self.enc_q.remove_weight_norm()
+        if hasattr(self, "enc_q"):
+            self.enc_q.remove_weight_norm()

     def __prepare_scriptable__(self):
         for hook in self.dec._forward_pre_hooks.values():
@@ -1087,13 +1091,14 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
         phone: torch.Tensor,
         phone_lengths: torch.Tensor,
         sid: torch.Tensor,
-        rate: Optional[torch.Tensor] = None,
+        skip_head: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
         m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
         z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if rate is not None:
-            head = int(z_p.shape[2] * (1.0 - rate.item()))
+        if skip_head is not None:
+            assert isinstance(skip_head, torch.Tensor)
+            head = int(skip_head.item())
             z_p = z_p[:, :, head:]
             x_mask = x_mask[:, :, head:]
             z = self.flow(z_p, x_mask, g=g, reverse=True)
@@ -1186,7 +1191,8 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
     def remove_weight_norm(self):
         self.dec.remove_weight_norm()
         self.flow.remove_weight_norm()
-        self.enc_q.remove_weight_norm()
+        if hasattr(self, "enc_q"):
+            self.enc_q.remove_weight_norm()

     def __prepare_scriptable__(self):
         for hook in self.dec._forward_pre_hooks.values():
@@ -1232,13 +1238,14 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
         phone: torch.Tensor,
         phone_lengths: torch.Tensor,
         sid: torch.Tensor,
-        rate: Optional[torch.Tensor] = None,
+        skip_head: Optional[torch.Tensor] = None,
     ):
         g = self.emb_g(sid).unsqueeze(-1)
         m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
         z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-        if rate is not None:
-            head = int(z_p.shape[2] * (1.0 - rate.item()))
+        if skip_head is not None:
+            assert isinstance(skip_head, torch.Tensor)
+            head = int(skip_head.item())
             z_p = z_p[:, :, head:]
             x_mask = x_mask[:, :, head:]
             z = self.flow(z_p, x_mask, g=g, reverse=True)

View File

@@ -34,4 +34,5 @@ def get_synthesizer(pth_path, device=torch.device("cpu")):
     net_g.load_state_dict(cpt["weight"], strict=False)
     net_g = net_g.float()
     net_g.eval().to(device)
+    net_g.remove_weight_norm()
     return net_g, cpt
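
With remove_weight_norm() applied inside get_synthesizer, the weight-norm reparameterization (weight_g / weight_v) is folded into plain weight tensors once at load time, so every subsequent forward pass skips the renormalization. Callers should not strip it again, since torch.nn.utils.remove_weight_norm raises once the reparameterization is gone. A hedged usage sketch, with a placeholder checkpoint path and get_synthesizer assumed to be imported from this module:

import torch

net_g, cpt = get_synthesizer("model.pth", device=torch.device("cpu"))
# net_g comes back in eval mode, as float32, on the target device, and already
# weight-norm-free, so the real-time loop can call net_g.infer(...) directly.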

View File

@@ -593,16 +593,18 @@ class RMVPE:
     def infer_from_audio(self, audio, thred=0.03):
         # torch.cuda.synchronize()
-        t0 = ttime()
+        # t0 = ttime()
+        if not torch.is_tensor(audio):
+            audio = torch.from_numpy(audio)
         mel = self.mel_extractor(
-            torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True
+            audio.float().to(self.device).unsqueeze(0), center=True
         )
         # print(123123123,mel.device.type)
         # torch.cuda.synchronize()
-        t1 = ttime()
+        # t1 = ttime()
         hidden = self.mel2hidden(mel)
         # torch.cuda.synchronize()
-        t2 = ttime()
+        # t2 = ttime()
         # print(234234,hidden.device.type)
         if "privateuseone" not in str(self.device):
             hidden = hidden.squeeze(0).cpu().numpy()
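
Accepting an existing tensor means a real-time pipeline that already holds the audio block as a torch tensor no longer pays for a numpy conversion on every block. A hedged sketch, assuming rmvpe is an already constructed RMVPE instance:

import numpy as np
import torch

f0 = rmvpe.infer_from_audio(np.zeros(16000, dtype=np.float32), thred=0.03)  # numpy path, unchanged
f0 = rmvpe.infer_from_audio(torch.zeros(16000), thred=0.03)                 # tensor path, skips torch.from_numpy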
@@ -613,7 +615,7 @@ class RMVPE:
         f0 = self.decode(hidden, thred=thred)
         # torch.cuda.synchronize()
-        t3 = ttime()
+        # t3 = ttime()
         # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
         return f0
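
Commenting out the ttime() calls together with the torch.cuda.synchronize() lines keeps host/device synchronization out of the hot path. If per-stage timing is ever needed again, both have to come back, because timing asynchronous CUDA work without a synchronize only measures kernel launch. A hedged, generic sketch of such a measurement, not code from this repository:

import time
import torch

def timed(stage_name, fn, *args):
    # Synchronize before and after so the wall clock covers the CUDA kernels
    # themselves, not just their launch (no-op on CPU-only setups).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.time()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print("%s: %.3f s" % (stage_name, time.time() - t0))
    return out

# e.g. inside infer_from_audio: hidden = timed("mel2hidden", self.mel2hidden, mel)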