@@ -27,6 +27,8 @@
 
 from typing import Optional
 
+import numpy as np
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -52,7 +54,7 @@ def regulate_len(durations, enc_out, pace: float = 1.0,
                                dim=1)[:, None, :]
     reps_cumsum = reps_cumsum.to(dtype)
 
-    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
+    range_ = torch.arange(max_len, device=enc_out.device)[None, :, None]
     mult = ((reps_cumsum[:, :, :-1] <= range_) &
             (reps_cumsum[:, :, 1:] > range_))
     mult = mult.to(dtype)
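The `regulate_len` change is a small device-placement fix: `torch.arange(max_len).to(device)` materializes the range on the CPU and then copies it to the target device, while passing `device=` allocates it there in one step. A minimal sketch of the equivalence (the device string is illustrative):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Old: allocate on CPU, then copy host -> device
    slow = torch.arange(1000).to(device)
    # New: allocate directly on the target device, no extra copy
    fast = torch.arange(1000, device=device)
    assert torch.equal(slow, fast)  # same values, cheaper construction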
@@ -218,13 +220,17 @@ def binarize_attention(self, attn, in_lens, out_lens):
         """
         b_size = attn.shape[0]
         with torch.no_grad():
-            attn_cpu = attn.data.cpu().numpy()
-            attn_out = torch.zeros_like(attn)
+            attn_out_cpu = np.zeros(attn.data.shape, dtype=np.float32)
+            log_attn_cpu = torch.log(attn.data).to(device='cpu', dtype=torch.float32)
+            log_attn_cpu = log_attn_cpu.numpy()
+            out_lens_cpu = out_lens.cpu()
+            in_lens_cpu = in_lens.cpu()
             for ind in range(b_size):
                 hard_attn = mas_width1(
-                    attn_cpu[ind, 0, :out_lens[ind], :in_lens[ind]])
-                attn_out[ind, 0, :out_lens[ind], :in_lens[ind]] = torch.tensor(
-                    hard_attn, device=attn.get_device())
+                    log_attn_cpu[ind, 0, :out_lens_cpu[ind], :in_lens_cpu[ind]])
+                attn_out_cpu[ind, 0, :out_lens_cpu[ind], :in_lens_cpu[ind]] = hard_attn
+            attn_out = torch.tensor(
+                attn_out_cpu, device=attn.get_device(), dtype=attn.dtype)
         return attn_out
 
     def binarize_attention_parallel(self, attn, in_lens, out_lens):
@@ -235,8 +241,8 @@ def binarize_attention_parallel(self, attn, in_lens, out_lens):
         attn: B x 1 x max_mel_len x max_text_len
         """
         with torch.no_grad():
-            attn_cpu = attn.data.cpu().numpy()
-            attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(),
+            log_attn_cpu = torch.log(attn.data).cpu().numpy()
+            attn_out = b_mas(log_attn_cpu, in_lens.cpu().numpy(),
                              out_lens.cpu().numpy(), width=1)
         return torch.from_numpy(attn_out).to(attn.get_device())
 
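Both binarizers now hand `mas_width1`/`b_mas` log-probabilities instead of raw attention weights. Monotonic alignment search is a Viterbi-style dynamic program that maximizes the sum of scores along a monotonic path, so the scores need to be in the log domain for the sum to correspond to a path probability; presumably the `log` previously lived inside the helpers and has been hoisted to the call sites here. A hedged sketch of the width-1 recurrence (not the repo's numba implementation, just the dynamic program it computes):

    import numpy as np

    def mas_width1_sketch(log_attn):
        """Width-1 monotonic alignment search over a [T_mel, T_text]
        matrix of log-probabilities. Each mel frame selects one text
        index, and the index never decreases from frame to frame."""
        T_mel, T_text = log_attn.shape
        value = np.full((T_mel, T_text), -np.inf, dtype=log_attn.dtype)
        value[0, 0] = log_attn[0, 0]  # every path starts at (0, 0)
        for i in range(1, T_mel):
            for j in range(min(i + 1, T_text)):
                stay = value[i - 1, j]
                advance = value[i - 1, j - 1] if j > 0 else -np.inf
                # best monotonic path ending at (i, j)
                value[i, j] = log_attn[i, j] + max(stay, advance)
        # backtrack from the last text token at the last mel frame
        opt = np.zeros_like(log_attn)
        j = T_text - 1
        for i in range(T_mel - 1, -1, -1):
            opt[i, j] = 1.0
            if i > 0 and j > 0 and value[i - 1, j - 1] >= value[i - 1, j]:
                j -= 1
        return opt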
@@ -245,6 +251,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75):
         (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense,
          speaker, attn_prior, audiopaths) = inputs
 
+        text_max_len = inputs.size(1)
         mel_max_len = mel_tgt.size(2)
 
         # Calculate speaker embedding
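The new `text_max_len` exists to pin the attention-mask width in the hunk below: with its default `max_len`, a helper like `mask_from_lens` typically sizes the mask from `input_lens.max()`, which can be narrower than the padded text dimension of the batch and make shapes disagree inside the attention. A sketch of the behavior this guards against (the helper body is assumed from its call sites, not copied from the repo):

    import torch

    def mask_from_lens_sketch(lens, max_len=None):
        if max_len is None:
            max_len = lens.max()  # old call: width follows longest item
        ids = torch.arange(max_len, device=lens.device)
        return ids < lens.unsqueeze(1)  # [B, max_len] bool mask

    lens = torch.tensor([3, 5])
    # if the batch is padded to width 7, the default mask is too narrow
    print(mask_from_lens_sketch(lens).shape)             # torch.Size([2, 5])
    print(mask_from_lens_sketch(lens, max_len=7).shape)  # torch.Size([2, 7])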
@@ -257,33 +264,32 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75):
         # Input FFT
         enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
 
+        # Predict durations
+        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
+        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
+
+        # Predict pitch
+        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
+
         # Alignment
         text_emb = self.encoder.word_emb(inputs)
 
         # make sure to do the alignments before folding
-        attn_mask = mask_from_lens(input_lens)[..., None] == 0
+        attn_mask = mask_from_lens(input_lens, max_len=text_max_len)
+        attn_mask = attn_mask[..., None] == 0
         # attn_mask should be 1 for unused timesteps in the text_enc_w_spkvec tensor
 
         attn_soft, attn_logprob = self.attention(
             mel_tgt, text_emb.permute(0, 2, 1), mel_lens, attn_mask,
             key_lens=input_lens, keys_encoded=enc_out, attn_prior=attn_prior)
 
-        attn_hard = self.binarize_attention_parallel(
-            attn_soft, input_lens, mel_lens)
+        attn_hard = self.binarize_attention(attn_soft, input_lens, mel_lens)
 
         # Viterbi --> durations
         attn_hard_dur = attn_hard.sum(2)[:, 0, :]
         dur_tgt = attn_hard_dur
-
         assert torch.all(torch.eq(dur_tgt.sum(dim=1), mel_lens))
 
-        # Predict durations
-        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
-        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
-
-        # Predict pitch
-        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
-
         # Average pitch over characters
         pitch_tgt = average_pitch(pitch_dense, dur_tgt)
 
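Moving the duration and pitch predictors ahead of the alignment block does not change what they compute, since both depend only on `enc_out`/`enc_mask`; it just groups the prediction heads together before the alignment-derived targets are built. The duration head itself works in log space: if it is trained toward `log(1 + dur_tgt)`, then `exp(...) - 1` plus the clamp recovers non-negative durations bounded by `max_duration`. A small round-trip sketch under that assumption:

    import torch

    max_duration = 75
    dur_tgt = torch.tensor([1., 3., 0., 7.])  # per-token frame counts
    log_dur_tgt = torch.log(dur_tgt + 1)      # assumed training target
    log_dur_pred = log_dur_tgt                # pretend a perfect predictor
    dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
    assert torch.allclose(dur_pred, dur_tgt)  # inversion is exact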