22import torch
33from multiprocessing import cpu_count
44
5+ def config_file_change_fp32 ():
6+ for config_file in ["32k.json" , "40k.json" , "48k.json" ]:
7+ with open (f"configs/{ config_file } " , "r" ) as f :
8+ strr = f .read ().replace ("true" , "false" )
9+ with open (f"configs/{ config_file } " , "w" ) as f :
10+ f .write (strr )
11+ with open ("trainset_preprocess_pipeline_print.py" , "r" ) as f :
12+ strr = f .read ().replace ("3.7" , "3.0" )
13+ with open ("trainset_preprocess_pipeline_print.py" , "w" ) as f :
14+ f .write (strr )
515
616class Config :
717 def __init__ (self ):
@@ -60,15 +70,7 @@ def device_config(self) -> tuple:
6070 ):
6171 print ("16系/10系显卡和P40强制单精度" )
6272 self .is_half = False
63- for config_file in ["32k.json" , "40k.json" , "48k.json" ]:
64- with open (f"configs/{ config_file } " , "r" ) as f :
65- strr = f .read ().replace ("true" , "false" )
66- with open (f"configs/{ config_file } " , "w" ) as f :
67- f .write (strr )
68- with open ("trainset_preprocess_pipeline_print.py" , "r" ) as f :
69- strr = f .read ().replace ("3.7" , "3.0" )
70- with open ("trainset_preprocess_pipeline_print.py" , "w" ) as f :
71- f .write (strr )
73+ config_file_change_fp32 ()
7274 else :
7375 self .gpu_name = None
7476 self .gpu_mem = int (
@@ -87,10 +89,12 @@ def device_config(self) -> tuple:
8789 print ("没有发现支持的N卡, 使用MPS进行推理" )
8890 self .device = "mps"
8991 self .is_half = False
92+ config_file_change_fp32 ()
9093 else :
9194 print ("没有发现支持的N卡, 使用CPU进行推理" )
9295 self .device = "cpu"
9396 self .is_half = False
97+ config_file_change_fp32 ()
9498
9599 if self .n_cpu == 0 :
96100 self .n_cpu = cpu_count ()
0 commit comments