
Commit bf00fe1

tomsiadev authored and nv-kkudrynski committed
[SIM/TF2] Release new version of SIM model with prebatching support
1 parent c903326 commit bf00fe1

10 files changed

Lines changed: 504 additions & 451 deletions


TensorFlow2/Recommendation/SIM/.gitignore

Lines changed: 0 additions & 1 deletion
@@ -15,4 +15,3 @@
 .ipynb_checkpoints/
 .idea/
 __pycache__
-results/

TensorFlow2/Recommendation/SIM/README.md

Lines changed: 239 additions & 329 deletions
Large diffs are not rendered by default.

TensorFlow2/Recommendation/SIM/main.py

Lines changed: 43 additions & 20 deletions
@@ -93,8 +93,8 @@ def init_logger(results_dir, filename):
 
 
 # In the future, select one of available dataloaders there (tfrecord, csv, etc...)
-def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, repeat_count=0,
-                      drop_remainder=False, amp=False, disable_cache=False):
+def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length, prefetch_size, num_parallel_calls=None, repeat_count=0,
+                      drop_remainder=False, amp=False, disable_cache=False, prebatch_size=0):
     return get_dataloader_tfrecord(
         paths,
         feature_spec=feature_spec,
@@ -105,7 +105,9 @@ def get_data_iterator(paths, feature_spec, batch_size, num_gpus, long_seq_length
         drop_remainder=drop_remainder,
         repeat_count=repeat_count,
         disable_cache=disable_cache,
-        prefetch_buffer_size=prefetch_size
+        prefetch_buffer_size=prefetch_size,
+        num_parallel_calls=num_parallel_calls,
+        prebatch_size=prebatch_size
     )
 
 
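For context, a minimal sketch of how these knobs typically map onto the tf.data API. This is an illustration only: get_dataloader_tfrecord itself is not part of this diff, and parse_fn is a hypothetical stand-in for the real TFRecord parser.

import tensorflow as tf

def make_dataset(files, parse_fn, batch_size, prefetch_size, num_parallel_calls=None):
    # None falls back to TF autotuning here, standing in for the CPU/GPU heuristic.
    ds = tf.data.TFRecordDataset(files)
    ds = ds.map(parse_fn, num_parallel_calls=num_parallel_calls or tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    return ds.prefetch(prefetch_size)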
@@ -243,10 +245,24 @@ def eval(model_fn, data_iterator, num_thresholds=8000, prefix=""):
             local_targets.append(targets)
             local_total_losses.append(loss_dict["total_loss"])
 
-    # concat all local variables into a single tensor
-    logits = tf.concat(local_logits, 0)
-    targets = tf.concat(local_targets, 0)
-    total_losses = tf.concat(local_total_losses, 0)
+    locals = [local_logits, local_targets, local_total_losses]
+    for i, local in enumerate(locals):
+
+        # wrap empty lists in tensor to allow tf.concat
+        if len(local) == 0:
+            local = tf.constant(local)
+
+        # concat all local variables into a single tensor
+        local = tf.concat(local, 0)
+
+        # for single element lists, tf.concat will produce shape=() instead of shape=(1,).
+        # reshape it for hvd.allgather to work
+        if len(local.shape) == 0:
+            local = tf.reshape(local, -1)
+
+        locals[i] = local
+
+    logits, targets, total_losses = locals
 
     if distributed:
         # gather from all nodes
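The corner case this loop guards against is easy to reproduce in isolation; a minimal sketch:

import tensorflow as tf

# tf.concat treats a single-element list as a degenerate case and returns the
# tensor unchanged, so a lone scalar stays shape=() instead of shape=(1,).
x = tf.concat([tf.constant(0.5)], 0)
print(x.shape)                  # ()
print(tf.reshape(x, -1).shape)  # (1,) -- rank 1, as hvd.allgather expects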
@@ -455,6 +471,9 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 @click.option(
     "--global_batch_size", default=131072, help="Batch size used to train/eval the model.", type=int
 )
+@click.option(
+    "--num_parallel_calls", default=None, help="Parallelism level for tf.data API. If None, heuristic based on number of CPUs and number of GPUs will be used."
+)
 @click.option(
     "--epochs", default=3, help="Train for the following number of epochs.", type=int
 )
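The heuristic mentioned in the help text is not shown in this commit; one plausible form, purely an assumption and not the repository's actual formula, splits the available CPU cores across the GPUs in use:

import multiprocessing

def default_num_parallel_calls(num_gpus):
    # Hypothetical heuristic: give each GPU's input pipeline an equal share of cores.
    return max(1, multiprocessing.cpu_count() // max(1, num_gpus))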
@@ -521,20 +540,23 @@ def inference(model, data_iterator, benchmark, performance_calculator):
 )
 @click.option(
     "--prefetch_train_size",
-    default=-1,
+    default=10,
     help="Number of batches to prefetch in training. "
-         "If == 0: No prefetching is done. "
-         "If < 0: Prefetch size is set to train_dataset_size // global_batch_size. ",
 )
 @click.option(
     "--prefetch_test_size",
     default=2,
     help="Number of batches to prefetch in testing"
 )
 @click.option(
-    "--train_dataset_size",
-    default=11796480,
-    help="Number of train samples. Used to set prefetching size (see --prefetch_train_size for more information."
+    "--prebatch_train_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to train dataset"
+)
+@click.option(
+    "--prebatch_test_size",
+    default=0,
+    help="Information about batch size applied during preprocessing to test dataset"
 )
 def main(
     mode: str,
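The new prebatch options describe data that was already grouped into fixed-size batches during preprocessing. A minimal sketch of the concept using only tf.data built-ins (an illustration of the idea, not the loader's actual code): records stored in prebatches can be unbatched and regrouped to any runtime batch size.

import tensorflow as tf

prebatch_size, batch_size = 4, 6
ds = tf.data.Dataset.range(24).batch(prebatch_size)  # stands in for prebatched records
ds = ds.unbatch().batch(batch_size)                  # restore the requested batch size
for batch in ds:
    print(batch.numpy())  # batches of 6, independent of the stored prebatch size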
@@ -554,6 +576,7 @@ def main(
     weight_decay: float,
     embedding_dim: int,
     global_batch_size: int,
+    num_parallel_calls: int,
     epochs: int,
     disable_cache: bool,
     drop_remainder: bool,
@@ -570,7 +593,8 @@ def main(
     intra_op_parallelism: int,
     prefetch_train_size: int,
     prefetch_test_size: int,
-    train_dataset_size: int
+    prebatch_train_size: int,
+    prebatch_test_size: int
 ):
     hvd.init()
 
@@ -636,20 +660,19 @@ def main(
     # since each tfrecord file must include all of the features, it is enough to read first chunk for each split.
     train_files = [dataset_dir / file for file in feature_spec.source_spec[TRAIN_MAPPING][0][FILES_SELECTOR]]
 
-    if prefetch_train_size < 0:
-        prefetch_train_size = train_dataset_size // global_batch_size
-
     data_iterator_train = get_data_iterator(
         train_files, feature_spec, batch_size, num_gpus, long_seq_length,
         repeat_count=repeat_count, drop_remainder=drop_remainder,
-        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size
+        amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_train_size,
+        num_parallel_calls=num_parallel_calls, prebatch_size=prebatch_train_size
     )
 
     if mode == "train":
         test_files = [dataset_dir / file for file in feature_spec.source_spec[TEST_MAPPING][0][FILES_SELECTOR]]
         data_iterator_test = get_data_iterator(
             test_files, feature_spec, batch_size, num_gpus, long_seq_length,
-            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size
+            amp=amp, disable_cache=disable_cache, prefetch_size=prefetch_test_size, num_parallel_calls=num_parallel_calls,
+            prebatch_size=prebatch_test_size
         )
     else:
         data_iterator_test = []  # otherwise not used
@@ -689,4 +712,4 @@ def main(
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()

TensorFlow2/Recommendation/SIM/preprocessing/ops.py

Lines changed: 33 additions & 0 deletions
@@ -73,6 +73,39 @@ def _preserve_data(offsets, values, new_values):
         new_values[i] = values[rowid]
 
 
+@numba.cuda.jit
+def _slice_rjust(max_elements, offsets, elements, new_offsets, new_elements):
+    rowid = numba.cuda.grid(1)
+    if rowid < new_offsets.size - 1:
+        row_size = min(offsets[rowid + 1] - offsets[rowid], max_elements)
+        offset = offsets[rowid + 1] - row_size
+        new_start = new_offsets[rowid + 1] - row_size
+
+        for i in range(row_size):
+            new_elements[new_start + i] = elements[offset + i]
+
+
+def slice_and_pad_left(seq_col, max_elements, pad_value=0):
+    c = seq_col._column
+    offsets = c.offsets.values
+    elements = c.elements.values
+
+    threads = THREADS
+    blocks = (offsets.size + threads - 1) // threads
+
+    new_offsets = cupy.arange(offsets.size, dtype=offsets.dtype) * max_elements
+
+    new_elements = cupy.full(
+        new_offsets[-1].item(), fill_value=pad_value, dtype=elements.dtype
+    )
+    _slice_rjust[blocks, threads](
+        max_elements, offsets, elements, new_offsets, new_elements
+    )
+
+    new_col = nvt_build_list_column(new_elements, new_offsets)
+    return new_col
+
+
 class ExplodeSequence:
     """
     For each row create a new one with a subsequence of the original list columns.
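A CPU reference for what the kernel computes: each row keeps at most its last max_elements items and is left-padded with pad_value to a fixed width. A NumPy sketch of those semantics, assuming the reading of the kernel above is correct:

import numpy as np

def slice_and_pad_left_ref(rows, max_elements, pad_value=0):
    out = np.full((len(rows), max_elements), pad_value)
    for i, row in enumerate(rows):
        tail = row[-max_elements:]                 # keep the last max_elements items
        out[i, max_elements - len(tail):] = tail   # right-justify, pad on the left
    return out

print(slice_and_pad_left_ref([[1, 2, 3, 4, 5], [6]], 3))
# [[3 4 5]
#  [0 0 6]]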
