Commit 67469e9

[FastPitch/PyT] Optimize CPU perf and remove GPU syncs

timmoon10 authored and nv-kkudrynski committed
1 parent 8909d58

3 files changed: 68 additions & 58 deletions

PyTorch/SpeechSynthesis/FastPitch/README.md

6 additions & 0 deletions

```diff
@@ -675,6 +675,12 @@ We're constantly refining and improving our performance on AI and HPC workloads
 
 ### Changelog
 
+July 2022
+- Performance optimizations, speedups up to 2x (DGX-1) and 2.5x (DGX A100)
+
+June 2022
+- MHA bug fix affecting models with > 1 attention heads
+
 August 2021
 - Improved quality of synthesized audio
 - Added capability to automatically align audio to transcripts during training without a pre-trained Tacotron 2 aligning model
```

PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py

37 additions & 39 deletions

```diff
@@ -17,69 +17,67 @@
 
 
 @jit(nopython=True)
-def mas(attn_map, width=1):
+def mas(log_attn_map, width=1):
     # assumes mel x text
-    opt = np.zeros_like(attn_map)
-    attn_map = np.log(attn_map)
-    attn_map[0, 1:] = -np.inf
-    log_p = np.zeros_like(attn_map)
-    log_p[0, :] = attn_map[0, :]
-    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
-    for i in range(1, attn_map.shape[0]):
-        for j in range(attn_map.shape[1]):  # for each text dim
+    opt = np.zeros_like(log_attn_map)
+    log_attn_map = log_attn_map.copy()
+    log_attn_map[0, 1:] = -np.inf
+    log_p = np.zeros_like(log_attn_map)
+    log_p[0, :] = log_attn_map[0, :]
+    prev_ind = np.zeros_like(log_attn_map, dtype=np.int64)
+    for i in range(1, log_attn_map.shape[0]):
+        for j in range(log_attn_map.shape[1]):  # for each text dim
             prev_j = np.arange(max(0, j-width), j+1)
             prev_log = np.array([log_p[i-1, prev_idx] for prev_idx in prev_j])
 
             ind = np.argmax(prev_log)
-            log_p[i, j] = attn_map[i, j] + prev_log[ind]
+            log_p[i, j] = log_attn_map[i, j] + prev_log[ind]
             prev_ind[i, j] = prev_j[ind]
 
     # now backtrack
-    curr_text_idx = attn_map.shape[1]-1
-    for i in range(attn_map.shape[0]-1, -1, -1):
+    curr_text_idx = log_attn_map.shape[1]-1
+    for i in range(log_attn_map.shape[0]-1, -1, -1):
         opt[i, curr_text_idx] = 1
         curr_text_idx = prev_ind[i, curr_text_idx]
     opt[0, curr_text_idx] = 1
     return opt
 
 
 @jit(nopython=True)
-def mas_width1(attn_map):
+def mas_width1(log_attn_map):
     """mas with hardcoded width=1"""
     # assumes mel x text
-    opt = np.zeros_like(attn_map)
-    attn_map = np.log(attn_map)
-    attn_map[0, 1:] = -np.inf
-    log_p = np.zeros_like(attn_map)
-    log_p[0, :] = attn_map[0, :]
-    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
-    for i in range(1, attn_map.shape[0]):
-        for j in range(attn_map.shape[1]):  # for each text dim
-            prev_log = log_p[i-1, j]
-            prev_j = j
-
-            if j-1 >= 0 and log_p[i-1, j-1] >= log_p[i-1, j]:
-                prev_log = log_p[i-1, j-1]
-                prev_j = j-1
-
-            log_p[i, j] = attn_map[i, j] + prev_log
-            prev_ind[i, j] = prev_j
+    neg_inf = log_attn_map.dtype.type(-np.inf)
+    log_p = log_attn_map.copy()
+    log_p[0, 1:] = neg_inf
+    for i in range(1, log_p.shape[0]):
+        prev_log1 = neg_inf
+        for j in range(log_p.shape[1]):
+            prev_log2 = log_p[i-1, j]
+            log_p[i, j] += max(prev_log1, prev_log2)
+            prev_log1 = prev_log2
 
     # now backtrack
-    curr_text_idx = attn_map.shape[1]-1
-    for i in range(attn_map.shape[0]-1, -1, -1):
-        opt[i, curr_text_idx] = 1
-        curr_text_idx = prev_ind[i, curr_text_idx]
-    opt[0, curr_text_idx] = 1
+    opt = np.zeros_like(log_p)
+    one = opt.dtype.type(1)
+    j = log_p.shape[1]-1
+    for i in range(log_p.shape[0]-1, 0, -1):
+        opt[i, j] = one
+        if log_p[i-1, j-1] >= log_p[i-1, j]:
+            j -= 1
+            if j == 0:
+                opt[1:i, j] = one
+                break
+    opt[0, j] = one
     return opt
 
 
 @jit(nopython=True, parallel=True)
-def b_mas(b_attn_map, in_lens, out_lens, width=1):
+def b_mas(b_log_attn_map, in_lens, out_lens, width=1):
     assert width == 1
-    attn_out = np.zeros_like(b_attn_map)
+    attn_out = np.zeros_like(b_log_attn_map)
 
-    for b in prange(b_attn_map.shape[0]):
-        out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
+    for b in prange(b_log_attn_map.shape[0]):
+        out = mas_width1(b_log_attn_map[b, 0, :out_lens[b], :in_lens[b]])
         attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
     return attn_out
```
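The substantive change in this file: all three entry points now expect attention maps that are already in log space, so the `np.log` call moves out of the numba kernels and into the callers in model.py, where it runs on the GPU before a single transfer. `mas` now copies its input before masking, since it no longer gets a fresh array back from `np.log`. The rewritten `mas_width1` also drops the `prev_ind` index matrix: with width 1 the best predecessor of `j` is always `j` or `j-1`, so the backtracking pass can recompute that comparison directly from `log_p`. A minimal sketch of the new contract on a toy alignment (illustrative, not part of the commit; assumes it runs from the FastPitch directory so `fastpitch.alignment` is importable):

```python
import numpy as np
from fastpitch.alignment import mas_width1

# Toy 4-frame x 3-token attention map; callers now take the log themselves.
attn = np.array([[0.7, 0.2, 0.1],
                 [0.4, 0.5, 0.1],
                 [0.1, 0.6, 0.3],
                 [0.1, 0.2, 0.7]], dtype=np.float32)

hard = mas_width1(np.log(attn))  # pass log-probabilities, not probabilities

# Each mel frame is assigned exactly one text token...
assert (hard.sum(axis=1) == 1).all()
# ...and the assignment is monotonic, as MAS guarantees.
assert (np.diff(hard.argmax(axis=1)) >= 0).all()
```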

PyTorch/SpeechSynthesis/FastPitch/fastpitch/model.py

25 additions & 19 deletions

```diff
@@ -27,6 +27,8 @@
 
 from typing import Optional
 
+import numpy as np
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```
```diff
@@ -52,7 +54,7 @@ def regulate_len(durations, enc_out, pace: float = 1.0,
                                dim=1)[:, None, :]
     reps_cumsum = reps_cumsum.to(dtype)
 
-    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
+    range_ = torch.arange(max_len, device=enc_out.device)[None, :, None]
     mult = ((reps_cumsum[:, :, :-1] <= range_) &
             (reps_cumsum[:, :, 1:] > range_))
     mult = mult.to(dtype)
```
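The `regulate_len` fix is the standard device-placement pattern: build the index tensor directly on the GPU instead of allocating it on the host and copying it over. A sketch of the before/after in isolation (sizes illustrative):

```python
import torch

max_len = 2048
device = torch.device('cuda')

# Before: torch.arange allocates on the CPU, then .to() enqueues an extra
# host-to-device copy of the result.
range_before = torch.arange(max_len).to(device)

# After: the tensor is materialized directly in device memory.
range_after = torch.arange(max_len, device=device)

assert torch.equal(range_before, range_after)  # same values, one less copy
```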
```diff
@@ -218,13 +220,17 @@ def binarize_attention(self, attn, in_lens, out_lens):
         """
         b_size = attn.shape[0]
         with torch.no_grad():
-            attn_cpu = attn.data.cpu().numpy()
-            attn_out = torch.zeros_like(attn)
+            attn_out_cpu = np.zeros(attn.data.shape, dtype=np.float32)
+            log_attn_cpu = torch.log(attn.data).to(device='cpu', dtype=torch.float32)
+            log_attn_cpu = log_attn_cpu.numpy()
+            out_lens_cpu = out_lens.cpu()
+            in_lens_cpu = in_lens.cpu()
             for ind in range(b_size):
                 hard_attn = mas_width1(
-                    attn_cpu[ind, 0, :out_lens[ind], :in_lens[ind]])
-                attn_out[ind, 0, :out_lens[ind], :in_lens[ind]] = torch.tensor(
-                    hard_attn, device=attn.get_device())
+                    log_attn_cpu[ind, 0, :out_lens_cpu[ind], :in_lens_cpu[ind]])
+                attn_out_cpu[ind, 0, :out_lens_cpu[ind], :in_lens_cpu[ind]] = hard_attn
+            attn_out = torch.tensor(
+                attn_out_cpu, device=attn.get_device(), dtype=attn.dtype)
         return attn_out
 
     def binarize_attention_parallel(self, attn, in_lens, out_lens):
```
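The rewritten `binarize_attention` batches the GPU work into one `torch.log` and one device-to-host copy, fills a NumPy buffer per sample, and uploads the result with a single `torch.tensor` call at the end. It also pulls `in_lens`/`out_lens` to the CPU once up front: indexing a CUDA tensor from Python inside the loop would otherwise force a blocking GPU-to-CPU read per batch element. That pattern in isolation (illustrative, not from the commit):

```python
import torch

out_lens = torch.tensor([80, 64, 72], device='cuda')

# Before: out_lens[ind] is a CUDA scalar; using it as a slice bound reads it
# back to the host, synchronizing on every iteration.
for ind in range(len(out_lens)):
    n = int(out_lens[ind])  # implicit device-to-host sync each time

# After: one bulk copy, then the loop touches only CPU memory.
out_lens_cpu = out_lens.cpu()
for ind in range(len(out_lens_cpu)):
    n = int(out_lens_cpu[ind])  # no sync inside the loop
```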
```diff
@@ -235,8 +241,8 @@ def binarize_attention_parallel(self, attn, in_lens, out_lens):
         attn: B x 1 x max_mel_len x max_text_len
         """
         with torch.no_grad():
-            attn_cpu = attn.data.cpu().numpy()
-            attn_out = b_mas(attn_cpu, in_lens.cpu().numpy(),
+            log_attn_cpu = torch.log(attn.data).cpu().numpy()
+            attn_out = b_mas(log_attn_cpu, in_lens.cpu().numpy(),
                              out_lens.cpu().numpy(), width=1)
         return torch.from_numpy(attn_out).to(attn.get_device())
 
```
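Both binarization paths now take the log on the GPU before a single transfer, and both reduce to `mas_width1` per batch element (the parallel path via numba's `prange` in `b_mas`), so they should remain interchangeable. A hypothetical consistency check, assuming an instantiated FastPitch module `model` on the GPU and float32 inputs (none of this is in the commit):

```python
import torch

attn_soft = torch.rand(4, 1, 80, 20, device='cuda').softmax(dim=3)
in_lens = torch.tensor([20, 18, 15, 12], device='cuda')
out_lens = torch.tensor([80, 70, 60, 50], device='cuda')

# `model` is hypothetical here; both methods run MAS on the same log map.
hard_loop = model.binarize_attention(attn_soft, in_lens, out_lens)
hard_par = model.binarize_attention_parallel(attn_soft, in_lens, out_lens)
assert torch.equal(hard_loop, hard_par)
```

Note that the training forward pass in the hunks below switches from `binarize_attention_parallel` to `binarize_attention`.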

```diff
@@ -245,6 +251,7 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75):
         (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense,
          speaker, attn_prior, audiopaths) = inputs
 
+        text_max_len = inputs.size(1)
         mel_max_len = mel_tgt.size(2)
 
         # Calculate speaker embedding
```
```diff
@@ -257,33 +264,32 @@ def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75):
         # Input FFT
         enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)
 
+        # Predict durations
+        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
+        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
+
+        # Predict pitch
+        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
+
         # Alignment
         text_emb = self.encoder.word_emb(inputs)
 
         # make sure to do the alignments before folding
-        attn_mask = mask_from_lens(input_lens)[..., None] == 0
+        attn_mask = mask_from_lens(input_lens, max_len=text_max_len)
+        attn_mask = attn_mask[..., None] == 0
         # attn_mask should be 1 for unused timesteps in the text_enc_w_spkvec tensor
 
         attn_soft, attn_logprob = self.attention(
             mel_tgt, text_emb.permute(0, 2, 1), mel_lens, attn_mask,
             key_lens=input_lens, keys_encoded=enc_out, attn_prior=attn_prior)
 
-        attn_hard = self.binarize_attention_parallel(
-            attn_soft, input_lens, mel_lens)
+        attn_hard = self.binarize_attention(attn_soft, input_lens, mel_lens)
 
         # Viterbi --> durations
         attn_hard_dur = attn_hard.sum(2)[:, 0, :]
         dur_tgt = attn_hard_dur
-
         assert torch.all(torch.eq(dur_tgt.sum(dim=1), mel_lens))
 
-        # Predict durations
-        log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
-        dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)
-
-        # Predict pitch
-        pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)
-
         # Average pitch over characters
         pitch_tgt = average_pitch(pitch_dense, dur_tgt)
 
```
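Two optimizations in this final hunk: the duration and pitch predictors move ahead of the alignment step, so their GPU kernels are already enqueued before the CPU-bound MAS binarization blocks the Python thread, and `mask_from_lens` gets an explicit `max_len` (a plain Python int from `inputs.size(1)`) so it never has to reduce `input_lens` on the device. A minimal sketch of why the explicit `max_len` matters, assuming a helper shaped like the repo's `mask_from_lens` (model.py imports the real one from common.utils):

```python
import torch

def mask_from_lens(lens, max_len=None):
    if max_len is None:
        max_len = lens.max()  # for CUDA lens, this scalar is read back on
                              # the host below -- a hidden sync
    ids = torch.arange(max_len, device=lens.device)
    return ids < lens.unsqueeze(1)

input_lens = torch.tensor([7, 5, 9], device='cuda')
text_max_len = 9  # known from inputs.size(1) without touching the GPU
mask = mask_from_lens(input_lens, max_len=text_max_len)  # no device reduction
```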