optimize: no json config writeback

This commit is contained in:
源文雨
2023-09-02 13:53:56 +08:00
parent 3f78b73ec7
commit ad85b02ed9
8 changed files with 85 additions and 61 deletions

View File

@@ -1,6 +1,7 @@
import argparse
import os
import sys
import json
from multiprocessing import cpu_count
import torch
@@ -10,23 +11,13 @@ import logging
logger = logging.getLogger(__name__)
def use_fp32_config():
for config_file in [
"v1/32k.json",
"v1/40k.json",
"v1/48k.json",
"v2/48k.json",
"v2/32k.json",
]:
with open(f"configs/{config_file}", "r") as f:
strr = f.read().replace("true", "false")
with open(f"configs/{config_file}", "w") as f:
f.write(strr)
with open("infer/modules/train/preprocess.py", "r") as f:
strr = f.read().replace("3.7", "3.0")
with open("infer/modules/train/preprocess.py", "w") as f:
f.write(strr)
version_config_list = [
"v1/32k.json",
"v1/40k.json",
"v1/48k.json",
"v2/48k.json",
"v2/32k.json",
]
def singleton_variable(func):
def wrapper(*args, **kwargs):
@@ -45,6 +36,7 @@ class Config:
self.is_half = True
self.n_cpu = 0
self.gpu_name = None
self.json_config = self.load_config_json()
self.gpu_mem = None
(
self.python_cmd,
@@ -57,6 +49,14 @@ class Config:
self.instead = ""
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
@staticmethod
def load_config_json() -> dict:
d = {}
for config_file in version_config_list:
with open(f"configs/{config_file}", "r") as f:
d[config_file] = json.load(f)
return d
@staticmethod
def arg_parse() -> tuple:
exe = sys.executable or "python"
@@ -101,6 +101,10 @@ class Config:
return True
except Exception:
return False
def use_fp32_config(self):
for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False
def device_config(self) -> tuple:
if torch.cuda.is_available():
@@ -116,7 +120,7 @@ class Config:
):
logger.info("Found GPU %s, force to fp32", self.gpu_name)
self.is_half = False
use_fp32_config()
self.use_fp32_config()
else:
logger.info("Found GPU %s", self.gpu_name)
self.gpu_mem = int(
@@ -135,12 +139,12 @@ class Config:
logger.info("No supported Nvidia GPU found")
self.device = self.instead = "mps"
self.is_half = False
use_fp32_config()
self.use_fp32_config()
else:
logger.info("No supported Nvidia GPU found")
self.device = self.instead = "cpu"
self.is_half = False
use_fp32_config()
self.use_fp32_config()
if self.n_cpu == 0:
self.n_cpu = cpu_count()