Skip to content

Commit 84b22b3

Browse files
alancuckinv-kkudrynski
authored and committed
[FastPitch/PyT] Resolve perf regression on DGX A100 + new perf tweaks
1 parent 469df9b commit 84b22b3

8 files changed

Lines changed: 158 additions & 50 deletions

File tree

PyTorch/SpeechRecognition/Jasper/common/filter_warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Mutes known and unrelated PyTorch warnings.
16+
17+
The warnings module keeps a list of filters. Importing it as late as possible
18+
prevents its filters from being overridden.
19+
"""
20+
1521
import warnings
1622

1723

PyTorch/SpeechRecognition/QuartzNet/common/filter_warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Mutes known and unrelated PyTorch warnings.
16+
17+
The warnings module keeps a list of filters. Importing it as late as possible
18+
prevents its filters from being overridden.
19+
"""
20+
1521
import warnings
1622

1723

PyTorch/SpeechSynthesis/FastPitch/common/filter_warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Mutes known and unrelated PyTorch warnings.
16+
17+
The warnings module keeps a list of filters. Importing it as late as possible
18+
prevents its filters from being overridden.
19+
"""
20+
1521
import warnings
1622

1723

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Data pipeline elements which wrap the data N times
16+
17+
A RepeatedDataLoader resets its iterator less frequently. This saves time
18+
on multi-GPU platforms and is invisible to the training loop.
19+
20+
NOTE: Repeating puts a block of (len(dataset) * repeats) int64s into RAM.
21+
Do not use more repeats than necessary (e.g., 10**6 to simulate infinity).
22+
"""
23+
24+
import itertools
25+
26+
from torch.utils.data import DataLoader
27+
from torch.utils.data.distributed import DistributedSampler
28+
29+
30+
class RepeatedDataLoader(DataLoader):
    """A DataLoader that serves the same iterator for `repeats` epochs.

    Resetting a DataLoader iterator is expensive (worker processes are
    re-spawned or re-synchronized); reusing one iterator across epochs
    avoids that cost and is invisible to a training loop that simply calls
    ``iter()`` once per epoch.

    NOTE(review): reuse only happens when ``self._iterator`` is kept alive
    between epochs, i.e. with ``persistent_workers=True`` and
    ``num_workers > 0``; otherwise every ``__iter__`` call returns a fresh
    iterator and `repeats` has no effect — confirm callers pass those args.

    Args:
        repeats: number of consecutive epochs served by one iterator.
        *args, **kwargs: forwarded verbatim to ``DataLoader``.
    """

    def __init__(self, repeats, *args, **kwargs):
        self.repeats = repeats
        # Fix: initialize explicitly instead of relying on the short-circuit
        # in __iter__ (`self._iterator is None`) to dodge an AttributeError.
        self.repeats_done = 0
        super().__init__(*args, **kwargs)

    def __iter__(self):
        if self._iterator is None or self.repeats_done >= self.repeats:
            # Start a fresh pass over the data (also caches self._iterator
            # when persistent_workers is enabled).
            self.repeats_done = 1
            return super().__iter__()
        else:
            # Hand back the live iterator for another epoch.
            self.repeats_done += 1
            return self._iterator
42+
43+
44+
class RepeatedDistributedSampler(DistributedSampler):
    """A DistributedSampler that draws indices for `repeats` epochs at once.

    Pairs with ``RepeatedDataLoader``: a single iterator yields the index
    sequences of `repeats` consecutive epochs, each drawn with the proper
    per-epoch shuffle (via ``set_epoch``), so multi-GPU shuffling behaves
    exactly as if the sampler had been re-iterated every epoch.

    NOTE: the concatenated indices for all repeats are materialized in RAM
    (roughly ``len(dataset) * repeats`` int64s), hence the hard cap below.

    Args:
        repeats: number of epochs' worth of indices per ``__iter__`` call.
        *args, **kwargs: forwarded verbatim to ``DistributedSampler``.

    Raises:
        ValueError: if `repeats` exceeds 10000.
    """

    def __init__(self, repeats, *args, **kwargs):
        self.repeats = repeats
        # Fix: validate with an explicit raise; `assert` is stripped when
        # Python runs with -O, silently disabling this RAM guard.
        if self.repeats > 10000:
            raise ValueError("Too many repeats overload RAM.")
        super().__init__(*args, **kwargs)

    def __iter__(self):
        # Draw indices for `self.repeats` epochs forward, advancing the
        # epoch so each repeat gets its own shuffle order.
        start_epoch = self.epoch
        iters = []
        for r in range(self.repeats):
            self.set_epoch(start_epoch + r)
            iters.append(super().__iter__())
        # Restore the epoch so the caller's own set_epoch bookkeeping
        # (one increment per training-loop epoch) stays consistent.
        self.set_epoch(start_epoch)

        return itertools.chain.from_iterable(iters)

PyTorch/SpeechSynthesis/FastPitch/scripts/inference_example.sh

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,27 @@ export CUDNN_V8_API_ENABLED=1
1212
: ${CPU:=false}
1313
: ${PHONE:=true}
1414

15-
# Mel-spectrogram generator (optional)
16-
: ${FASTPITCH="pretrained_models/fastpitch/nvidia_fastpitch_210824.pt"}
15+
# Paths to pre-trained models downloadable from NVIDIA NGC (LJSpeech-1.1)
16+
FASTPITCH_LJ="pretrained_models/fastpitch/nvidia_fastpitch_210824.pt"
17+
HIFIGAN_LJ="pretrained_models/hifigan/hifigan_gen_checkpoint_10000_ft.pt"
18+
WAVEGLOW_LJ="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"
1719

18-
# Vocoder; set only one
19-
: ${WAVEGLOW="pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt"}
20-
: ${HIFIGAN=""}
20+
# Mel-spectrogram generator (optional; can synthesize from ground-truth spectrograms)
21+
: ${FASTPITCH=$FASTPITCH_LJ}
2122

22-
[[ "$FASTPITCH" == "pretrained_models/fastpitch/nvidia_fastpitch_210824.pt" && ! -f "$FASTPITCH" ]] && { echo "Downloading $FASTPITCH from NGC..."; bash scripts/download_models.sh fastpitch; }
23-
[[ "$WAVEGLOW" == "pretrained_models/waveglow/nvidia_waveglow256pyt_fp16.pt" && ! -f "$WAVEGLOW" ]] && { echo "Downloading $WAVEGLOW from NGC..."; bash scripts/download_models.sh waveglow; }
23+
# Vocoder (set only one)
24+
: ${HIFIGAN=$HIFIGAN_LJ}
25+
# : ${WAVEGLOW=$WAVEGLOW_LJ}
26+
27+
[[ "$FASTPITCH" == "$FASTPITCH_LJ" && ! -f "$FASTPITCH" ]] && { echo "Downloading $FASTPITCH from NGC..."; bash scripts/download_models.sh fastpitch; }
28+
[[ "$WAVEGLOW" == "$WAVEGLOW_LJ" && ! -f "$WAVEGLOW" ]] && { echo "Downloading $WAVEGLOW from NGC..."; bash scripts/download_models.sh waveglow; }
29+
[[ "$HIFIGAN" == "$HIFIGAN_LJ" && ! -f "$HIFIGAN" ]] && { echo "Downloading $HIFIGAN from NGC..."; bash scripts/download_models.sh hifigan-finetuned-fastpitch; }
30+
31+
if [[ "$HIFIGAN" == "$HIFIGAN_LJ" && "$FASTPITCH" != "$FASTPITCH_LJ" ]]; then
32+
echo -e "\nNOTE: Using HiFi-GAN checkpoint trained for the LJSpeech-1.1 dataset."
33+
echo -e "NOTE: If you're using a different dataset, consider training a new HiFi-GAN model or switch to WaveGlow."
34+
echo -e "NOTE: See $0 for details.\n"
35+
fi
2436

2537
# Synthesis
2638
: ${SPEAKER:=0}

PyTorch/SpeechSynthesis/FastPitch/scripts/train.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ ARGS+=" --weight-decay 1e-6"
6161
ARGS+=" --grad-clip-thresh 1000.0"
6262
ARGS+=" --dur-predictor-loss-scale 0.1"
6363
ARGS+=" --pitch-predictor-loss-scale 0.1"
64+
ARGS+=" --trainloader-repeats 100"
65+
ARGS+=" --validation-freq 10"
6466

6567
# Autoalign & new features
6668
ARGS+=" --kl-loss-start-epoch 0"

PyTorch/SpeechSynthesis/FastPitch/train.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import os
3131
import time
3232
from collections import defaultdict, OrderedDict
33+
from itertools import cycle
3334

3435
import numpy as np
3536
import torch
@@ -43,6 +44,8 @@
4344
import common.tb_dllogger as logger
4445
import models
4546
from common.tb_dllogger import log
47+
from common.repeated_dataloader import (RepeatedDataLoader,
48+
RepeatedDistributedSampler)
4649
from common.text import cmudict
4750
from common.utils import BenchmarkStats, Checkpointer, prepare_tmp
4851
from fastpitch.attn_loss_function import AttentionBinarizationLoss
@@ -90,6 +93,8 @@ def parse_args(parser):
9093
help='Gradually increase the hard attention loss term')
9194
train.add_argument('--benchmark-epochs-num', type=int, default=20,
9295
help='Number of epochs for calculating final stats')
96+
train.add_argument('--validation-freq', type=int, default=1,
97+
help='Validate every N epochs to use less compute')
9398

9499
opt = parser.add_argument_group('optimization setup')
95100
opt.add_argument('--optimizer', type=str, default='lamb',
@@ -132,6 +137,10 @@ def parse_args(parser):
132137
help='Capture leading silence with a space token')
133138
data.add_argument('--append-space-to-text', action='store_true',
134139
help='Capture trailing silence with a space token')
140+
data.add_argument('--num-workers', type=int, default=6,
141+
help='Subprocesses for train and val DataLoaders')
142+
data.add_argument('--trainloader-repeats', type=int, default=100,
143+
help='Repeats the dataset to prolong epochs')
135144

136145
cond = parser.add_argument_group('data for conditioning')
137146
cond.add_argument('--n-speakers', type=int, default=1,
@@ -194,19 +203,13 @@ def init_distributed(args, world_size, rank):
194203
print("Done initializing distributed training")
195204

196205

197-
def validate(model, epoch, total_iter, criterion, valset, batch_size,
198-
collate_fn, distributed_run, batch_to_gpu, ema=False):
199-
"""Handles all the validation scoring and printing"""
206+
def validate(model, epoch, total_iter, criterion, val_loader, distributed_run,
207+
batch_to_gpu, ema=False):
200208
was_training = model.training
201209
model.eval()
202210

203211
tik = time.perf_counter()
204212
with torch.no_grad():
205-
val_sampler = DistributedSampler(valset) if distributed_run else None
206-
val_loader = DataLoader(valset, num_workers=4, shuffle=False,
207-
sampler=val_sampler,
208-
batch_size=batch_size, pin_memory=False,
209-
collate_fn=collate_fn)
210213
val_meta = defaultdict(float)
211214
val_num_frames = 0
212215
for i, batch in enumerate(val_loader):
@@ -221,9 +224,9 @@ def validate(model, epoch, total_iter, criterion, valset, batch_size,
221224
else:
222225
for k, v in meta.items():
223226
val_meta[k] += v
224-
val_num_frames = num_frames.item()
227+
val_num_frames += num_frames.item()
225228

226-
val_meta = {k: v / len(valset) for k, v in val_meta.items()}
229+
val_meta = {k: v / len(val_loader.dataset) for k, v in val_meta.items()}
227230

228231
val_meta['took'] = time.perf_counter() - tik
229232

@@ -232,7 +235,7 @@ def validate(model, epoch, total_iter, criterion, valset, batch_size,
232235
data=OrderedDict([
233236
('loss', val_meta['loss'].item()),
234237
('mel_loss', val_meta['mel_loss'].item()),
235-
('frames/s', num_frames.item() / val_meta['took']),
238+
('frames/s', val_num_frames / val_meta['took']),
236239
('took', val_meta['took'])]),
237240
)
238241

@@ -313,6 +316,11 @@ def main():
313316

314317
if distributed_run:
315318
init_distributed(args, args.world_size, args.local_rank)
319+
else:
320+
if args.trainloader_repeats > 1:
321+
print('WARNING: Disabled --trainloader-repeats, supported only for'
322+
' multi-GPU data loading.')
323+
args.trainloader_repeats = 1
316324

317325
device = torch.device('cuda' if args.cuda else 'cpu')
318326
model_config = models.get_model_config('FastPitch', args)
@@ -345,7 +353,7 @@ def main():
345353
model, device_ids=[args.local_rank], output_device=args.local_rank,
346354
find_unused_parameters=True)
347355

348-
train_state = {'epoch': 1, 'total_iter': 0}
356+
train_state = {'epoch': 1, 'total_iter': 1}
349357
checkpointer = Checkpointer(args.output, args.keep_milestones)
350358

351359
checkpointer.maybe_load(model, optimizer, scaler, train_state, args,
@@ -368,21 +376,26 @@ def main():
368376
valset = TTSDataset(audiopaths_and_text=args.validation_files, **vars(args))
369377

370378
if distributed_run:
371-
train_sampler, shuffle = DistributedSampler(trainset), False
379+
train_sampler = RepeatedDistributedSampler(args.trainloader_repeats,
380+
trainset, drop_last=True)
381+
val_sampler = DistributedSampler(valset)
382+
shuffle = False
372383
else:
373-
train_sampler, shuffle = None, True
384+
train_sampler, val_sampler, shuffle = None, None, True
374385

375386
# 4 workers are optimal on DGX-1 (from epoch 2 onwards)
376-
train_loader = DataLoader(trainset, num_workers=4, shuffle=shuffle,
377-
sampler=train_sampler, batch_size=args.batch_size,
378-
pin_memory=True, persistent_workers=True,
379-
drop_last=True, collate_fn=collate_fn)
380-
387+
kw = {'num_workers': args.num_workers, 'batch_size': args.batch_size,
388+
'collate_fn': collate_fn}
389+
train_loader = RepeatedDataLoader(args.trainloader_repeats, trainset,
390+
shuffle=shuffle, drop_last=True,
391+
sampler=train_sampler, pin_memory=True,
392+
persistent_workers=True, **kw)
393+
val_loader = DataLoader(valset, shuffle=False, sampler=val_sampler,
394+
pin_memory=False, **kw)
381395
if args.ema_decay:
382396
mt_ema_params = init_multi_tensor_ema(model, ema_model)
383397

384398
model.train()
385-
386399
bmark_stats = BenchmarkStats()
387400

388401
torch.cuda.synchronize()
@@ -397,22 +410,15 @@ def main():
397410
if distributed_run:
398411
train_loader.sampler.set_epoch(epoch)
399412

400-
accumulated_steps = 0
401413
iter_loss = 0
402414
iter_num_frames = 0
403415
iter_meta = {}
404416
iter_start_time = time.perf_counter()
405417

406-
epoch_iter = 0
407-
num_iters = len(train_loader) // args.grad_accumulation
408-
for batch in train_loader:
409-
410-
if accumulated_steps == 0:
411-
if epoch_iter == num_iters:
412-
break
413-
total_iter += 1
414-
epoch_iter += 1
415-
418+
epoch_iter = 1
419+
for batch, accum_step in zip(train_loader,
420+
cycle(range(args.grad_accumulation))):
421+
if accum_step == 0:
416422
adjust_learning_rate(total_iter, optimizer, args.learning_rate,
417423
args.warmup_steps)
418424

@@ -461,12 +467,11 @@ def main():
461467
if np.isnan(reduced_loss):
462468
raise Exception("loss is NaN")
463469

464-
accumulated_steps += 1
465470
iter_loss += reduced_loss
466471
iter_num_frames += reduced_num_frames
467472
iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}
468473

469-
if accumulated_steps % args.grad_accumulation == 0:
474+
if accum_step % args.grad_accumulation == 0:
470475

471476
logger.log_grads_tb(total_iter, model)
472477
if args.amp:
@@ -491,6 +496,7 @@ def main():
491496
epoch_num_frames += iter_num_frames
492497
epoch_mel_loss += iter_mel_loss
493498

499+
num_iters = len(train_loader) // args.grad_accumulation
494500
log((epoch, epoch_iter, num_iters), tb_total_steps=total_iter,
495501
subset='train', data=OrderedDict([
496502
('loss', iter_loss),
@@ -502,12 +508,16 @@ def main():
502508
('lrate', optimizer.param_groups[0]['lr'])]),
503509
)
504510

505-
accumulated_steps = 0
506511
iter_loss = 0
507512
iter_num_frames = 0
508513
iter_meta = {}
509514
iter_start_time = time.perf_counter()
510515

516+
if epoch_iter == num_iters:
517+
break
518+
epoch_iter += 1
519+
total_iter += 1
520+
511521
# Finished epoch
512522
epoch_loss /= epoch_iter
513523
epoch_mel_loss /= epoch_iter
@@ -523,13 +533,13 @@ def main():
523533
bmark_stats.update(epoch_num_frames, epoch_loss, epoch_mel_loss,
524534
epoch_time)
525535

526-
validate(model, epoch, total_iter, criterion, valset, args.batch_size,
527-
collate_fn, distributed_run, batch_to_gpu)
536+
if epoch % args.validation_freq == 0:
537+
validate(model, epoch, total_iter, criterion, val_loader,
538+
distributed_run, batch_to_gpu)
528539

529-
if args.ema_decay > 0:
530-
validate(ema_model, epoch, total_iter, criterion, valset,
531-
args.batch_size, collate_fn, distributed_run, batch_to_gpu,
532-
ema=True)
540+
if args.ema_decay > 0:
541+
validate(ema_model, epoch, total_iter, criterion, val_loader,
542+
distributed_run, batch_to_gpu, ema=True)
533543

534544
# save before making sched.step() for proper loading of LR
535545
checkpointer.maybe_save(args, model, ema_model, optimizer, scaler,
@@ -538,10 +548,11 @@ def main():
538548

539549
# Finished training
540550
if len(bmark_stats) > 0:
541-
log((), tb_total_steps=None, subset='train_avg', data=bmark_stats.get(args.benchmark_epochs_num))
551+
log((), tb_total_steps=None, subset='train_avg',
552+
data=bmark_stats.get(args.benchmark_epochs_num))
542553

543-
validate(model, None, total_iter, criterion, valset, args.batch_size,
544-
collate_fn, distributed_run, batch_to_gpu)
554+
validate(model, None, total_iter, criterion, val_loader, distributed_run,
555+
batch_to_gpu)
545556

546557

547558
if __name__ == '__main__':

PyTorch/SpeechSynthesis/HiFi-GAN/common/filter_warnings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
"""Mutes known and unrelated PyTorch warnings.
16+
17+
The warnings module keeps a list of filters. Importing it as late as possible
18+
prevents its filters from being overridden.
19+
"""
20+
1521
import warnings
1622

1723

0 commit comments

Comments
 (0)