optimize the streaming f0 extractors (#1168)

This commit is contained in:
yxlllc
2023-09-02 15:45:50 +08:00
committed by GitHub
parent ad85b02ed9
commit 0fc160c03e
3 changed files with 91 additions and 71 deletions

View File

@@ -261,9 +261,9 @@ if __name__ == "__main__":
[
sg.Text(i18n("采样长度")),
sg.Slider(
range=(0.09, 2.4),
range=(0.05, 2.4),
key="block_time",
resolution=0.03,
resolution=0.01,
orientation="h",
default_value=data.get("block_time", ""),
enable_events=True,
@@ -455,18 +455,20 @@ if __name__ == "__main__":
inp_q,
opt_q,
device,
self.rvc if hasattr(self, "rvc") else None
)
self.config.samplerate = self.rvc.tgt_sr
self.config.crossfade_time = min(
self.config.crossfade_time, self.config.block_time
)
self.block_frame = int(self.config.block_time * self.config.samplerate)
self.zc = self.rvc.tgt_sr // 100
self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
self.block_frame_16k = 160 * self.block_frame // self.zc
self.crossfade_frame = int(
self.config.crossfade_time * self.config.samplerate
)
self.sola_search_frame = int(0.01 * self.config.samplerate)
self.extra_frame = int(self.config.extra_time * self.config.samplerate)
self.zc = self.rvc.tgt_sr // 100
self.input_wav: np.ndarray = np.zeros(
int(
np.ceil(
@@ -482,6 +484,7 @@ if __name__ == "__main__":
),
dtype="float32",
)
self.input_wav_res: torch.Tensor= torch.zeros(160 * len(self.input_wav) // self.zc)
self.output_wav_cache: torch.Tensor = torch.zeros(
int(
np.ceil(
@@ -573,18 +576,14 @@ if __name__ == "__main__":
for i in range(db_threhold.shape[0]):
if db_threhold[i]:
indata[i * hop_length : (i + 1) * hop_length] = 0
self.input_wav[:] = np.append(self.input_wav[self.block_frame :], indata)
self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :]
self.input_wav[-self.block_frame: ] = indata
# infer
inp = torch.from_numpy(self.input_wav).to(device)
res1 = self.resampler(inp)
###55%
rate1 = self.block_frame / (
self.extra_frame
+ self.crossfade_frame
+ self.sola_search_frame
+ self.block_frame
)
rate2 = (
inp = torch.from_numpy(self.input_wav[-self.block_frame-2*self.zc :]).to(device)
self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
self.input_wav_res[-self.block_frame_16k-160 :] = self.resampler(inp)[160 :]
rate = (
self.crossfade_frame + self.sola_search_frame + self.block_frame
) / (
self.extra_frame
@@ -592,11 +591,14 @@ if __name__ == "__main__":
+ self.sola_search_frame
+ self.block_frame
)
f0_extractor_frame = self.block_frame_16k + 800
if self.config.f0method == 'rmvpe':
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
res2 = self.rvc.infer(
res1,
res1[-self.block_frame :].cpu().numpy(),
rate1,
rate2,
self.input_wav_res,
self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
self.block_frame_16k,
rate,
self.pitch,
self.pitchf,
self.config.f0method,