@@ -93,8 +93,8 @@ def init_logger(results_dir, filename):
9393
9494
9595# In the future, select one of available dataloaders there (tfrecord, csv, etc...)
96- def get_data_iterator (paths , feature_spec , batch_size , num_gpus , long_seq_length , prefetch_size , repeat_count = 0 ,
97- drop_remainder = False , amp = False , disable_cache = False ):
96+ def get_data_iterator (paths , feature_spec , batch_size , num_gpus , long_seq_length , prefetch_size , num_parallel_calls = None , repeat_count = 0 ,
97+ drop_remainder = False , amp = False , disable_cache = False , prebatch_size = 0 ):
9898 return get_dataloader_tfrecord (
9999 paths ,
100100 feature_spec = feature_spec ,
@@ -105,7 +105,9 @@ def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length
105105 drop_remainder = drop_remainder ,
106106 repeat_count = repeat_count ,
107107 disable_cache = disable_cache ,
108- prefetch_buffer_size = prefetch_size
108+ prefetch_buffer_size = prefetch_size ,
109+ num_parallel_calls = num_parallel_calls ,
110+ prebatch_size = prebatch_size
109111 )
110112
111113
@@ -243,10 +245,24 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
243245 local_targets .append (targets )
244246 local_total_losses .append (loss_dict ["total_loss" ])
245247
246- # concat all local variables into a single tensor
247- logits = tf .concat (local_logits , 0 )
248- targets = tf .concat (local_targets , 0 )
249- total_losses = tf .concat (local_total_losses , 0 )
248+ locals = [local_logits , local_targets , local_total_losses ]
249+ for i , local in enumerate (locals ):
250+
251+ # wrap empty lists in tensor to allow tf.concat
252+ if len (local ) == 0 :
253+ local = tf .constant (local )
254+
255+ # concat all local variables into a single tensor
256+ local = tf .concat (local , 0 )
257+
258+ # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
259+ # reshape it for hvd.allgather to work
260+ if len (local .shape ) == 0 :
261+ local = tf .reshape (local , - 1 )
262+
263+ locals [i ] = local
264+
265+ logits , targets , total_losses = locals
250266
251267 if distributed :
252268 # gather from all nodes
@@ -455,6 +471,9 @@ def inference(model, data_iterator, benchmark, performance_calculator):
455471@click .option (
456472 "--global_batch_size" , default = 131072 , help = "Batch size used to train/eval the model." , type = int
457473)
474+ @click .option (
475+ "--num_parallel_calls" , default = None , help = "Parallelism level for tf.data API. If None, heuristic based on number of CPUs and number of GPUs will be used."
476+ )
458477@click .option (
459478 "--epochs" , default = 3 , help = "Train for the following number of epochs." , type = int
460479)
@@ -521,20 +540,23 @@ def inference(model, data_iterator, benchmark, performance_calculator):
521540)
522541@click .option (
523542 "--prefetch_train_size" ,
524- default = - 1 ,
543+ default = 10 ,
525544 help = "Number of batches to prefetch in training. "
526- "If == 0: No prefetching is done. "
527- "If < 0: Prefetch size is set to train_dataset_size // global_batch_size. " ,
528545)
529546@click .option (
530547 "--prefetch_test_size" ,
531548 default = 2 ,
532549 help = "Number of batches to prefetch in testing"
533550)
534551@click .option (
535- "--train_dataset_size" ,
536- default = 11796480 ,
537- help = "Number of train samples. Used to set prefetching size (see --prefetch_train_size for more information."
552+ "--prebatch_train_size" ,
553+ default = 0 ,
554+ help = "Information about batch size applied during preprocessing to train dataset"
555+ )
556+ @click .option (
557+ "--prebatch_test_size" ,
558+ default = 0 ,
559+ help = "Information about batch size applied during preprocessing to test dataset"
538560)
539561def main (
540562 mode : str ,
@@ -554,6 +576,7 @@ def main(
554576 weight_decay : float ,
555577 embedding_dim : int ,
556578 global_batch_size : int ,
579+ num_parallel_calls : int ,
557580 epochs : int ,
558581 disable_cache : bool ,
559582 drop_remainder : bool ,
@@ -570,7 +593,8 @@ def main(
570593 intra_op_parallelism : int ,
571594 prefetch_train_size : int ,
572595 prefetch_test_size : int ,
573- train_dataset_size : int
596+ prebatch_train_size : int ,
597+ prebatch_test_size : int
574598):
575599 hvd .init ()
576600
@@ -636,20 +660,19 @@ def main(
636660 # since each tfrecord file must include all of the features, it is enough to read first chunk for each split.
637661 train_files = [dataset_dir / file for file in feature_spec .source_spec [TRAIN_MAPPING ][0 ][FILES_SELECTOR ]]
638662
639- if prefetch_train_size < 0 :
640- prefetch_train_size = train_dataset_size // global_batch_size
641-
642663 data_iterator_train = get_data_iterator (
643664 train_files , feature_spec , batch_size , num_gpus , long_seq_length ,
644665 repeat_count = repeat_count , drop_remainder = drop_remainder ,
645- amp = amp , disable_cache = disable_cache , prefetch_size = prefetch_train_size
666+ amp = amp , disable_cache = disable_cache , prefetch_size = prefetch_train_size ,
667+ num_parallel_calls = num_parallel_calls , prebatch_size = prebatch_train_size
646668 )
647669
648670 if mode == "train" :
649671 test_files = [dataset_dir / file for file in feature_spec .source_spec [TEST_MAPPING ][0 ][FILES_SELECTOR ]]
650672 data_iterator_test = get_data_iterator (
651673 test_files , feature_spec , batch_size , num_gpus , long_seq_length ,
652- amp = amp , disable_cache = disable_cache , prefetch_size = prefetch_test_size
674+ amp = amp , disable_cache = disable_cache , prefetch_size = prefetch_test_size , num_parallel_calls = num_parallel_calls ,
675+ prebatch_size = prebatch_test_size
653676 )
654677 else :
655678 data_iterator_test = [] # otherwise not used
@@ -689,4 +712,4 @@ def main(
689712
690713
691714if __name__ == "__main__" :
692- main ()
715+ main ()