@@ -434,16 +434,17 @@ def main():
434434
435435 metrics .finish_iter () # done accumulating
436436 if iters_all % args .step_logs_interval == 0 :
437- logger .log ((epoch , iter_ , iters_num ), metrics ,
438- scope = 'train' , tb_iter = iters_all )
437+ logger .log ((epoch , iter_ , iters_num ), metrics , scope = 'train' ,
438+ tb_iter = iters_all , flush_log = True )
439439
440440 assert is_last_accum_step
441441 metrics .finish_epoch ()
442442 logger .log ((epoch ,), metrics , scope = 'train_avg' , flush_log = True )
443443
444444 if epoch % args .validation_interval == 0 :
445445 validate (args , gen , mel_spec , mpd , msd , val_loader , val_metrics )
446- logger .log ((epoch ,), val_metrics , scope = 'val' , tb_iter = iters_all )
446+ logger .log ((epoch ,), val_metrics , scope = 'val' , tb_iter = iters_all ,
447+ flush_log = True )
447448
448449 # validation samples
449450 if epoch % args .samples_interval == 0 and args .local_rank == 0 :
@@ -477,6 +478,7 @@ def main():
477478 gen , mpd , msd , optim_g , optim_d , scaler_g , scaler_d , epoch ,
478479 train_state , args , gen_config , train_setup ,
479480 gen_ema = gen_ema , mpd_ema = mpd_ema , msd_ema = msd_ema )
481+ logger .flush ()
480482
481483 sched_g .step ()
482484 sched_d .step ()
@@ -488,10 +490,10 @@ def main():
488490
489491 # finished training
490492 if epochs_done > 0 :
491- logger .log ((), metrics , scope = 'train_benchmark' )
493+ logger .log ((), metrics , scope = 'train_benchmark' , flush_log = True )
492494 if epoch % args .validation_interval != 0 : # val metrics are not up-to-date
493495 validate (args , gen , mel_spec , mpd , msd , val_loader , val_metrics )
494- logger .log ((), val_metrics , scope = 'val' )
496+ logger .log ((), val_metrics , scope = 'val' , flush_log = True )
495497 else :
496498 print_once (f'Finished without training after epoch { args .epochs } .' )
497499
0 commit comments