1- import torch , numpy as np
1+ import torch , numpy as np , pdb
22import torch .nn as nn
33import torch .nn .functional as F
4-
5-
4+ import torch ,pdb
5+ import numpy as np
6+ import torch .nn .functional as F
7+ from scipy .signal import get_window
8+ from librosa .util import pad_center , tiny ,normalize
# STFT code from https://github.com/pseeth/torch-stft/blob/master/torch_stft/util.py
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """Compute the sum-square envelope of a window function at a given hop length.

    (From librosa 0.6.) This is used to estimate modulation effects induced by
    windowing observations in short-time Fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`.
    n_frames : int > 0
        The number of analysis frames.
    hop_length : int > 0
        The number of samples to advance between frames (default: 200).
    win_length : int > 0 or None
        The length of the window function (default: 800). If ``None`` is
        passed explicitly, `n_fft` is used instead.
    n_fft : int > 0
        The length of each analysis frame (default: 800).
    dtype : np.dtype
        The data type of the output.
    norm : optional
        Forwarded unchanged as the `norm` argument of `normalize`.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function.
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length.
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = normalize(win_sq, norm=norm) ** 2
    # `size` is passed by keyword (as STFT.__init__ does): recent librosa
    # releases reject it as a positional argument.
    win_sq = pad_center(win_sq, size=n_fft)

    # Fill the envelope by overlap-adding the squared window at every hop.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
52+
class STFT(torch.nn.Module):
    """STFT / inverse STFT built from 1D convolutions and transpose convolutions.

    The windowed Fourier bases are precomputed in ``__init__`` and registered
    as buffers, so they move together with the module on ``.to(device)``.
    """

    def __init__(self, filter_length=1024, hop_length=512, win_length=None,
                 window='hann'):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky so there are some cases that probably won't work as working
        out the same sizes before and after in all overlap add setups is tough. Right now,
        this code should work with hop lengths that are half the filter length (50% overlap
        between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        # Overwritten by transform(); inverse() uses it to trim its output.
        # Starting at None means inverse() can run before any transform()
        # call (slicing with ``:None`` keeps everything) instead of raising
        # AttributeError.
        self.num_samples = None
        self.pad_amount = int(self.filter_length / 2)
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the first filter_length / 2 + 1 rows and stack their real and
        # imaginary parts, so a single real-valued convolution yields both.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([
            np.real(fourier_basis[:cutoff, :]),
            np.imag(fourier_basis[:cutoff, :]),
        ])
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        assert (filter_length >= self.win_length)
        # Get the window and zero center pad it to filter_length.
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # Window the bases.
        forward_basis *= fft_window
        inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data, return_phase=False):
        """Take input data (audio) to STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Keyword Arguments:
            return_phase {bool} -- When True, also compute and return the phase.
                The default (False) returns only the magnitude, which is what
                existing callers of this method expect. (default: {False})

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Only when ``return_phase`` is True: phase of STFT
                with shape (num_batch, num_frequencies, num_frames), returned
                as the second element of a ``(magnitude, phase)`` tuple
        """
        num_batches = input_data.shape[0]
        num_samples = input_data.shape[-1]

        # Remembered so that inverse() can trim back to the input length.
        self.num_samples = num_samples

        # Similar to librosa, reflect-pad the input on both sides of the
        # sample axis by half the filter length.
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (self.pad_amount, self.pad_amount, 0, 0, 0, 0),
            mode='reflect',
        ).squeeze(1)
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        # The first `cutoff` channels hold the real part, the rest the imaginary part.
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        if not return_phase:
            return magnitude

        # Phase is computed on detached values, so no gradient flows through it.
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ```transform``` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # Remove modulation effects: divide by the window envelope wherever
            # it is not (numerically) zero.
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            # Keep the indices and the envelope on the same device as the data.
            approx_nonzero_indices = approx_nonzero_indices.to(inverse_transform.device)
            window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # Scale by hop ratio.
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the leading padding added by transform(), then trim to the
        # original length (a no-op while num_samples is still None).
        inverse_transform = inverse_transform[..., self.pad_amount:]
        inverse_transform = inverse_transform[..., :self.num_samples]
        inverse_transform = inverse_transform.squeeze(1)

        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        # transform() returns only the magnitude by default, so the phase has
        # to be requested explicitly for the round trip.
        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
194+ from time import time as ttime
6195class BiGRU (nn .Module ):
7196 def __init__ (self , input_features , hidden_features , num_layers ):
8197 super (BiGRU , self ).__init__ ()
@@ -250,9 +439,11 @@ def __init__(
250439 )
251440
252441 def forward (self , mel ):
442+ # print(mel.shape)
253443 mel = mel .transpose (- 1 , - 2 ).unsqueeze (1 )
254444 x = self .cnn (self .unet (mel )).transpose (1 , 2 ).flatten (- 2 )
255445 x = self .fc (x )
446+ # print(x.shape)
256447 return x
257448
258449
@@ -301,18 +492,33 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
301492 keyshift_key = str (keyshift ) + "_" + str (audio .device )
302493 if keyshift_key not in self .hann_window :
303494 self .hann_window [keyshift_key ] = torch .hann_window (win_length_new ).to (
495+ # "cpu"if(audio.device.type=="privateuseone") else audio.device
304496 audio .device
305497 )
306- fft = torch .stft (
307- audio ,
308- n_fft = n_fft_new ,
309- hop_length = hop_length_new ,
310- win_length = win_length_new ,
311- window = self .hann_window [keyshift_key ],
312- center = center ,
313- return_complex = True ,
314- )
315- magnitude = torch .sqrt (fft .real .pow (2 ) + fft .imag .pow (2 ))
498+ # fft = torch.stft(#doesn't support pytorch_dml
499+ # # audio.cpu() if(audio.device.type=="privateuseone")else audio,
500+ # audio,
501+ # n_fft=n_fft_new,
502+ # hop_length=hop_length_new,
503+ # win_length=win_length_new,
504+ # window=self.hann_window[keyshift_key],
505+ # center=center,
506+ # return_complex=True,
507+ # )
508+ # magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
509+ # print(1111111111)
510+ # print(222222222222222,audio.device,self.is_half)
511+ if hasattr (self , "stft" ) == False :
512+ # print(n_fft_new,hop_length_new,win_length_new,audio.shape)
513+ self .stft = STFT (
514+ filter_length = n_fft_new ,
515+ hop_length = hop_length_new ,
516+ win_length = win_length_new ,
517+ window = 'hann'
518+ ).to (audio .device )
519+ magnitude = self .stft .transform (audio )#phase
520+ # if (audio.device.type == "privateuseone"):
521+ # magnitude=magnitude.to(audio.device)
316522 if keyshift != 0 :
317523 size = self .n_fft // 2 + 1
318524 resize = magnitude .size (1 )
@@ -323,19 +529,13 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
323529 if self .is_half == True :
324530 mel_output = mel_output .half ()
325531 log_mel_spec = torch .log (torch .clamp (mel_output , min = self .clamp ))
532+ # print(log_mel_spec.device.type)
326533 return log_mel_spec
327534
328535
329536class RMVPE :
330537 def __init__ (self , model_path , is_half , device = None ):
331538 self .resample_kernel = {}
332- model = E2E (4 , 1 , (2 , 2 ))
333- ckpt = torch .load (model_path , map_location = "cpu" )
334- model .load_state_dict (ckpt )
335- model .eval ()
336- if is_half == True :
337- model = model .half ()
338- self .model = model
339539 self .resample_kernel = {}
340540 self .is_half = is_half
341541 if device is None :
@@ -344,7 +544,19 @@ def __init__(self, model_path, is_half, device=None):
344544 self .mel_extractor = MelSpectrogram (
345545 is_half , 128 , 16000 , 1024 , 160 , None , 30 , 8000
346546 ).to (device )
347- self .model = self .model .to (device )
547+ if ("privateuseone" in str (device )):
548+ import onnxruntime as ort
549+ ort_session = ort .InferenceSession ("rmvpe.onnx" , providers = ["DmlExecutionProvider" ])
550+ self .model = ort_session
551+ else :
552+ model = E2E (4 , 1 , (2 , 2 ))
553+ ckpt = torch .load (model_path , map_location = "cpu" )
554+ model .load_state_dict (ckpt )
555+ model .eval ()
556+ if is_half == True :
557+ model = model .half ()
558+ self .model = model
559+ self .model = self .model .to (device )
348560 cents_mapping = 20 * np .arange (360 ) + 1997.3794084376191
349561 self .cents_mapping = np .pad (cents_mapping , (4 , 4 )) # 368
350562
@@ -354,7 +566,12 @@ def mel2hidden(self, mel):
354566 mel = F .pad (
355567 mel , (0 , 32 * ((n_frames - 1 ) // 32 + 1 ) - n_frames ), mode = "reflect"
356568 )
357- hidden = self .model (mel )
569+ if ("privateuseone" in str (self .device ) ):
570+ onnx_input_name = self .model .get_inputs ()[0 ].name
571+ onnx_outputs_names = self .model .get_outputs ()[0 ].name
572+ hidden = self .model .run ([onnx_outputs_names ], input_feed = {onnx_input_name : mel .cpu ().numpy ()})[0 ]
573+ else :
574+ hidden = self .model (mel )
358575 return hidden [:, :n_frames ]
359576
360577 def decode (self , hidden , thred = 0.03 ):
@@ -365,21 +582,26 @@ def decode(self, hidden, thred=0.03):
365582 return f0
366583
367584 def infer_from_audio (self , audio , thred = 0.03 ):
368- audio = torch .from_numpy (audio ).float ().to (self .device ).unsqueeze (0 )
369585 # torch.cuda.synchronize()
370- # t0=ttime()
371- mel = self .mel_extractor (audio , center = True )
586+ t0 = ttime ()
587+ mel = self .mel_extractor (torch .from_numpy (audio ).float ().to (self .device ).unsqueeze (0 ), center = True )
588+ # print(123123123,mel.device.type)
372589 # torch.cuda.synchronize()
373- # t1=ttime()
590+ t1 = ttime ()
374591 hidden = self .mel2hidden (mel )
375592 # torch.cuda.synchronize()
376- # t2=ttime()
377- hidden = hidden .squeeze (0 ).cpu ().numpy ()
593+ t2 = ttime ()
594+ # print(234234,hidden.device.type)
595+ if ("privateuseone" not in str (self .device )):
596+ hidden = hidden .squeeze (0 ).cpu ().numpy ()
597+ else :
598+ hidden = hidden [0 ]
378599 if self .is_half == True :
379600 hidden = hidden .astype ("float32" )
601+
380602 f0 = self .decode (hidden , thred = thred )
381603 # torch.cuda.synchronize()
382- # t3=ttime()
604+ t3 = ttime ()
383605 # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
384606 return f0
385607
@@ -410,22 +632,23 @@ def to_local_average_cents(self, salience, thred=0.05):
410632 return devided
411633
412634
413- # if __name__ == '__main__':
414- # audio, sampling_rate = sf.read("卢本伟语录~1.wav")
415- # if len(audio.shape) > 1:
416- # audio = librosa.to_mono(audio.transpose(1, 0))
417- # audio_bak = audio.copy()
418- # if sampling_rate != 16000:
419- # audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
420- # model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
421- # thred = 0.03 # 0.01
422- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
423- # rmvpe = RMVPE(model_path,is_half=False, device=device)
424- # t0=ttime()
425- # f0 = rmvpe.infer_from_audio(audio, thred=thred)
426- # f0 = rmvpe.infer_from_audio(audio, thred=thred)
427- # f0 = rmvpe.infer_from_audio(audio, thred=thred)
428- # f0 = rmvpe.infer_from_audio(audio, thred=thred)
429- # f0 = rmvpe.infer_from_audio(audio, thred=thred)
430- # t1=ttime()
431- # print(f0.shape,t1-t0)
if __name__ == '__main__':
    # Manual smoke test: load one clip, run pitch inference once and print the
    # shape of the result together with the elapsed wall-clock time.
    import sys

    import librosa
    import soundfile as sf

    # Usage: <script> [wav_path [model_path]]
    # Without arguments the original developer-machine paths are used.
    wav_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav"
    )
    model_path = (
        sys.argv[2]
        if len(sys.argv) > 2
        else r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
    )

    audio, sampling_rate = sf.read(wav_path)
    if len(audio.shape) > 1:
        # Multi-channel input: mix down to mono.
        audio = librosa.to_mono(audio.transpose(1, 0))
    if sampling_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
    thred = 0.03  # 0.01
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rmvpe = RMVPE(model_path, is_half=False, device=device)
    t0 = ttime()
    f0 = rmvpe.infer_from_audio(audio, thred=thred)
    t1 = ttime()
    print(f0.shape, t1 - t0)
0 commit comments