|
| 1 | +''' |
| 2 | +0416后的更新: |
| 3 | + 引入config中half |
| 4 | + 重建npy而不用填写 |
| 5 | + v2支持 |
| 6 | + 无f0模型支持 |
| 7 | + 修复 |
| 8 | +
|
| 9 | + int16: |
| 10 | + 增加无索引支持 |
| 11 | + f0算法改harvest(怎么看就只有这个会影响CPU占用),但是不这么改效果不好 |
| 12 | +''' |
1 | 13 | import os, sys, traceback |
2 | | - |
3 | 14 | now_dir = os.getcwd() |
4 | 15 | sys.path.append(now_dir) |
| 16 | +from config import Config |
| 17 | +is_half=Config().is_half |
5 | 18 | import PySimpleGUI as sg |
6 | 19 | import sounddevice as sd |
7 | 20 | import noisereduce as nr |
|
13 | 26 | import scipy.signal as signal |
14 | 27 |
|
15 | 28 | # import matplotlib.pyplot as plt |
16 | | -from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono |
| 29 | +from infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono,SynthesizerTrnMs768NSFsid,SynthesizerTrnMs768NSFsid_nono |
17 | 30 | from i18n import I18nAuto |
18 | 31 |
|
19 | 32 | i18n = I18nAuto() |
@@ -50,20 +63,33 @@ def __init__( |
50 | 63 | ) |
51 | 64 | self.model = models[0] |
52 | 65 | self.model = self.model.to(device) |
53 | | - self.model = self.model.half() |
| 66 | + if(is_half==True): |
| 67 | + self.model = self.model.half() |
| 68 | + else: |
| 69 | + self.model = self.model.float() |
54 | 70 | self.model.eval() |
55 | 71 | cpt = torch.load(pth_path, map_location="cpu") |
56 | 72 | self.tgt_sr = cpt["config"][-1] |
57 | 73 | cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk |
58 | 74 | self.if_f0 = cpt.get("f0", 1) |
59 | | - if self.if_f0 == 1: |
60 | | - self.net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=True) |
61 | | - else: |
62 | | - self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) |
| 75 | + self.version = cpt.get("version", "v1") |
| 76 | + if version == "v1": |
| 77 | + if if_f0 == 1: |
| 78 | + net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half) |
| 79 | + else: |
| 80 | + net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"]) |
| 81 | + elif version == "v2": |
| 82 | + if if_f0 == 1: |
| 83 | + net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half) |
| 84 | + else: |
| 85 | + net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"]) |
63 | 86 | del self.net_g.enc_q |
64 | 87 | print(self.net_g.load_state_dict(cpt["weight"], strict=False)) |
65 | 88 | self.net_g.eval().to(device) |
66 | | - self.net_g.half() |
| 89 | + if(is_half==True): |
| 90 | + self.net_g=self.net_g.half() |
| 91 | + else: |
| 92 | + self.net_g=self.net_g.float() |
67 | 93 | except: |
68 | 94 | print(traceback.format_exc()) |
69 | 95 |
|
@@ -116,34 +142,33 @@ def infer(self, feats: torch.Tensor) -> np.ndarray: |
116 | 142 | inputs = { |
117 | 143 | "source": feats.half().to(device), |
118 | 144 | "padding_mask": padding_mask.to(device), |
119 | | - "output_layer": 9, # layer 9 |
| 145 | + "output_layer": 9 if self.version == "v1" else 12, |
120 | 146 | } |
121 | 147 | torch.cuda.synchronize() |
122 | 148 | with torch.no_grad(): |
123 | 149 | logits = self.model.extract_features(**inputs) |
124 | | - feats = self.model.final_proj(logits[0]) |
| 150 | + feats = model.final_proj(logits[0]) if self.version == "v1" else logits[0] |
125 | 151 |
|
126 | 152 | ####索引优化 |
127 | | - if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0: |
128 | | - npy = feats[0].cpu().numpy().astype("float32") |
129 | | - |
130 | | - # _, I = self.index.search(npy, 1) |
131 | | - # npy = self.big_npy[I.squeeze()].astype("float16") |
132 | | - |
133 | | - score, ix = self.index.search(npy, k=8) |
134 | | - weight = np.square(1 / score) |
135 | | - weight /= weight.sum(axis=1, keepdims=True) |
136 | | - npy = np.sum( |
137 | | - self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1 |
138 | | - ).astype("float16") |
139 | | - |
140 | | - feats = ( |
141 | | - torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate |
142 | | - + (1 - self.index_rate) * feats |
143 | | - ) |
144 | | - else: |
145 | | - print("index search FAIL or disabled") |
146 | | - |
| 153 | + try: |
| 154 | + if hasattr(self, "index") and hasattr(self, "big_npy") and self.index_rate != 0: |
| 155 | + npy = feats[0].cpu().numpy().astype("float32") |
| 156 | + score, ix = self.index.search(npy, k=8) |
| 157 | + weight = np.square(1 / score) |
| 158 | + weight /= weight.sum(axis=1, keepdims=True) |
| 159 | + npy = np.sum( |
| 160 | + self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1 |
| 161 | + ) |
| 162 | + if(is_half==True):npy=npy.astype("float16") |
| 163 | + feats = ( |
| 164 | + torch.from_numpy(npy).unsqueeze(0).to(device) * self.index_rate |
| 165 | + + (1 - self.index_rate) * feats |
| 166 | + ) |
| 167 | + else: |
| 168 | + print("index search FAIL or disabled") |
| 169 | + except: |
| 170 | + traceback.print_exc() |
| 171 | + print("index search FAIL") |
147 | 172 | feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) |
148 | 173 | torch.cuda.synchronize() |
149 | 174 | print(feats.shape) |
|
0 commit comments