3030import os
3131import time
3232from collections import defaultdict , OrderedDict
33+ from itertools import cycle
3334
3435import numpy as np
3536import torch
4344import common .tb_dllogger as logger
4445import models
4546from common .tb_dllogger import log
47+ from common .repeated_dataloader import (RepeatedDataLoader ,
48+ RepeatedDistributedSampler )
4649from common .text import cmudict
4750from common .utils import BenchmarkStats , Checkpointer , prepare_tmp
4851from fastpitch .attn_loss_function import AttentionBinarizationLoss
@@ -90,6 +93,8 @@ def parse_args(parser):
9093 help = 'Gradually increase the hard attention loss term' )
9194 train .add_argument ('--benchmark-epochs-num' , type = int , default = 20 ,
9295 help = 'Number of epochs for calculating final stats' )
96+ train .add_argument ('--validation-freq' , type = int , default = 1 ,
97+ help = 'Validate every N epochs to use less compute' )
9398
9499 opt = parser .add_argument_group ('optimization setup' )
95100 opt .add_argument ('--optimizer' , type = str , default = 'lamb' ,
@@ -132,6 +137,10 @@ def parse_args(parser):
132137 help = 'Capture leading silence with a space token' )
133138 data .add_argument ('--append-space-to-text' , action = 'store_true' ,
134139 help = 'Capture trailing silence with a space token' )
140+ data .add_argument ('--num-workers' , type = int , default = 6 ,
141+ help = 'Subprocesses for train and val DataLoaders' )
142+ data .add_argument ('--trainloader-repeats' , type = int , default = 100 ,
143+ help = 'Repeats the dataset to prolong epochs' )
135144
136145 cond = parser .add_argument_group ('data for conditioning' )
137146 cond .add_argument ('--n-speakers' , type = int , default = 1 ,
@@ -194,19 +203,13 @@ def init_distributed(args, world_size, rank):
194203 print ("Done initializing distributed training" )
195204
196205
197- def validate (model , epoch , total_iter , criterion , valset , batch_size ,
198- collate_fn , distributed_run , batch_to_gpu , ema = False ):
199- """Handles all the validation scoring and printing"""
206+ def validate (model , epoch , total_iter , criterion , val_loader , distributed_run ,
207+ batch_to_gpu , ema = False ):
200208 was_training = model .training
201209 model .eval ()
202210
203211 tik = time .perf_counter ()
204212 with torch .no_grad ():
205- val_sampler = DistributedSampler (valset ) if distributed_run else None
206- val_loader = DataLoader (valset , num_workers = 4 , shuffle = False ,
207- sampler = val_sampler ,
208- batch_size = batch_size , pin_memory = False ,
209- collate_fn = collate_fn )
210213 val_meta = defaultdict (float )
211214 val_num_frames = 0
212215 for i , batch in enumerate (val_loader ):
@@ -221,9 +224,9 @@ def validate(model, epoch, total_iter, criterion, valset, batch_size,
221224 else :
222225 for k , v in meta .items ():
223226 val_meta [k ] += v
224- val_num_frames = num_frames .item ()
227+ val_num_frames + = num_frames .item ()
225228
226- val_meta = {k : v / len (valset ) for k , v in val_meta .items ()}
229+ val_meta = {k : v / len (val_loader . dataset ) for k , v in val_meta .items ()}
227230
228231 val_meta ['took' ] = time .perf_counter () - tik
229232
@@ -232,7 +235,7 @@ def validate(model, epoch, total_iter, criterion, valset, batch_size,
232235 data = OrderedDict ([
233236 ('loss' , val_meta ['loss' ].item ()),
234237 ('mel_loss' , val_meta ['mel_loss' ].item ()),
235- ('frames/s' , num_frames . item () / val_meta ['took' ]),
238+ ('frames/s' , val_num_frames / val_meta ['took' ]),
236239 ('took' , val_meta ['took' ])]),
237240 )
238241
@@ -313,6 +316,11 @@ def main():
313316
314317 if distributed_run :
315318 init_distributed (args , args .world_size , args .local_rank )
319+ else :
320+ if args .trainloader_repeats > 1 :
321+ print ('WARNING: Disabled --trainloader-repeats, supported only for'
322+ ' multi-GPU data loading.' )
323+ args .trainloader_repeats = 1
316324
317325 device = torch .device ('cuda' if args .cuda else 'cpu' )
318326 model_config = models .get_model_config ('FastPitch' , args )
@@ -345,7 +353,7 @@ def main():
345353 model , device_ids = [args .local_rank ], output_device = args .local_rank ,
346354 find_unused_parameters = True )
347355
348- train_state = {'epoch' : 1 , 'total_iter' : 0 }
356+ train_state = {'epoch' : 1 , 'total_iter' : 1 }
349357 checkpointer = Checkpointer (args .output , args .keep_milestones )
350358
351359 checkpointer .maybe_load (model , optimizer , scaler , train_state , args ,
@@ -368,21 +376,26 @@ def main():
368376 valset = TTSDataset (audiopaths_and_text = args .validation_files , ** vars (args ))
369377
370378 if distributed_run :
371- train_sampler , shuffle = DistributedSampler (trainset ), False
379+ train_sampler = RepeatedDistributedSampler (args .trainloader_repeats ,
380+ trainset , drop_last = True )
381+ val_sampler = DistributedSampler (valset )
382+ shuffle = False
372383 else :
373- train_sampler , shuffle = None , True
384+ train_sampler , val_sampler , shuffle = None , None , True
374385
375386 # 4 workers are optimal on DGX-1 (from epoch 2 onwards)
376- train_loader = DataLoader (trainset , num_workers = 4 , shuffle = shuffle ,
377- sampler = train_sampler , batch_size = args .batch_size ,
378- pin_memory = True , persistent_workers = True ,
379- drop_last = True , collate_fn = collate_fn )
380-
387+ kw = {'num_workers' : args .num_workers , 'batch_size' : args .batch_size ,
388+ 'collate_fn' : collate_fn }
389+ train_loader = RepeatedDataLoader (args .trainloader_repeats , trainset ,
390+ shuffle = shuffle , drop_last = True ,
391+ sampler = train_sampler , pin_memory = True ,
392+ persistent_workers = True , ** kw )
393+ val_loader = DataLoader (valset , shuffle = False , sampler = val_sampler ,
394+ pin_memory = False , ** kw )
381395 if args .ema_decay :
382396 mt_ema_params = init_multi_tensor_ema (model , ema_model )
383397
384398 model .train ()
385-
386399 bmark_stats = BenchmarkStats ()
387400
388401 torch .cuda .synchronize ()
@@ -397,22 +410,15 @@ def main():
397410 if distributed_run :
398411 train_loader .sampler .set_epoch (epoch )
399412
400- accumulated_steps = 0
401413 iter_loss = 0
402414 iter_num_frames = 0
403415 iter_meta = {}
404416 iter_start_time = time .perf_counter ()
405417
406- epoch_iter = 0
407- num_iters = len (train_loader ) // args .grad_accumulation
408- for batch in train_loader :
409-
410- if accumulated_steps == 0 :
411- if epoch_iter == num_iters :
412- break
413- total_iter += 1
414- epoch_iter += 1
415-
418+ epoch_iter = 1
419+ for batch , accum_step in zip (train_loader ,
420+ cycle (range (args .grad_accumulation ))):
421+ if accum_step == 0 :
416422 adjust_learning_rate (total_iter , optimizer , args .learning_rate ,
417423 args .warmup_steps )
418424
@@ -461,12 +467,11 @@ def main():
461467 if np .isnan (reduced_loss ):
462468 raise Exception ("loss is NaN" )
463469
464- accumulated_steps += 1
465470 iter_loss += reduced_loss
466471 iter_num_frames += reduced_num_frames
467472 iter_meta = {k : iter_meta .get (k , 0 ) + meta .get (k , 0 ) for k in meta }
468473
469- if accumulated_steps % args .grad_accumulation == 0 :
474+ if accum_step % args .grad_accumulation == 0 :
470475
471476 logger .log_grads_tb (total_iter , model )
472477 if args .amp :
@@ -491,6 +496,7 @@ def main():
491496 epoch_num_frames += iter_num_frames
492497 epoch_mel_loss += iter_mel_loss
493498
499+ num_iters = len (train_loader ) // args .grad_accumulation
494500 log ((epoch , epoch_iter , num_iters ), tb_total_steps = total_iter ,
495501 subset = 'train' , data = OrderedDict ([
496502 ('loss' , iter_loss ),
@@ -502,12 +508,16 @@ def main():
502508 ('lrate' , optimizer .param_groups [0 ]['lr' ])]),
503509 )
504510
505- accumulated_steps = 0
506511 iter_loss = 0
507512 iter_num_frames = 0
508513 iter_meta = {}
509514 iter_start_time = time .perf_counter ()
510515
516+ if epoch_iter == num_iters :
517+ break
518+ epoch_iter += 1
519+ total_iter += 1
520+
511521 # Finished epoch
512522 epoch_loss /= epoch_iter
513523 epoch_mel_loss /= epoch_iter
@@ -523,13 +533,13 @@ def main():
523533 bmark_stats .update (epoch_num_frames , epoch_loss , epoch_mel_loss ,
524534 epoch_time )
525535
526- validate (model , epoch , total_iter , criterion , valset , args .batch_size ,
527- collate_fn , distributed_run , batch_to_gpu )
536+ if epoch % args .validation_freq == 0 :
537+ validate (model , epoch , total_iter , criterion , val_loader ,
538+ distributed_run , batch_to_gpu )
528539
529- if args .ema_decay > 0 :
530- validate (ema_model , epoch , total_iter , criterion , valset ,
531- args .batch_size , collate_fn , distributed_run , batch_to_gpu ,
532- ema = True )
540+ if args .ema_decay > 0 :
541+ validate (ema_model , epoch , total_iter , criterion , val_loader ,
542+ distributed_run , batch_to_gpu , ema = True )
533543
534544 # save before making sched.step() for proper loading of LR
535545 checkpointer .maybe_save (args , model , ema_model , optimizer , scaler ,
@@ -538,10 +548,11 @@ def main():
538548
539549 # Finished training
540550 if len (bmark_stats ) > 0 :
541- log ((), tb_total_steps = None , subset = 'train_avg' , data = bmark_stats .get (args .benchmark_epochs_num ))
551+ log ((), tb_total_steps = None , subset = 'train_avg' ,
552+ data = bmark_stats .get (args .benchmark_epochs_num ))
542553
543- validate (model , None , total_iter , criterion , valset , args . batch_size ,
544- collate_fn , distributed_run , batch_to_gpu )
554+ validate (model , None , total_iter , criterion , val_loader , distributed_run ,
555+ batch_to_gpu )
545556
546557
547558if __name__ == '__main__' :
0 commit comments