20 | 20 | import json
21 | 21 |
22 | 22 | from distributed_embeddings.python.layers import dist_model_parallel as dmp |
23 | | -from distributed_embeddings.python.layers import embedding |
24 | 23 |
25 | 24 | from utils.checkpointing import get_variable_path |
26 | 25 |
29 | 28 |
30 | 29 | sparse_model_parameters = ['use_mde_embeddings', 'embedding_dim', 'column_slice_threshold', |
31 | 30 | 'embedding_zeros_initializer', 'embedding_trainable', 'categorical_cardinalities', |
32 | | - 'concat_embedding', 'cpu_offloading_threshold_gb'] |
| 31 | + 'concat_embedding', 'cpu_offloading_threshold_gb', |
| 32 | + 'data_parallel_input', 'row_slice_threshold', 'data_parallel_threshold'] |
| 33 | + |
| 34 | +def _gigabytes_to_elements(gb, dtype=tf.float32): |
| 35 | + if gb is None: |
| 36 | + return None |
| 37 | + |
| 38 | + if dtype == tf.float32: |
| 39 | + bytes_per_element = 4 |
| 40 | + else: |
| 41 | + raise ValueError(f'Unsupported dtype: {dtype}') |
| 42 | + |
| 43 | + return int(gb * 10**9 // bytes_per_element)
33 | 44 |
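The thresholds handed to the distributed embedding below are element counts rather than bytes, which is presumably why the CPU-offloading budget is converted from gigabytes up front. A standalone restatement of the helper with a quick sanity check (a sketch; `tf` is assumed to be imported at the top of the file, as elsewhere in this module):

```python
import tensorflow as tf

# Restated for illustration: convert a gigabyte budget into an element count.
def _gigabytes_to_elements(gb, dtype=tf.float32):
    if gb is None:
        return None
    if dtype == tf.float32:
        bytes_per_element = 4
    else:
        raise ValueError(f'Unsupported dtype: {dtype}')
    return int(gb * 10**9 // bytes_per_element)

# 16 GB of fp32 parameters -> 4e9 elements (4 bytes each).
assert _gigabytes_to_elements(16) == 4_000_000_000
assert _gigabytes_to_elements(None) is None
```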
34 | 45 | class SparseModel(tf.keras.Model): |
35 | 46 | def __init__(self, **kwargs): |
@@ -61,21 +72,21 @@ def _create_embeddings(self): |
61 | 72 | for table_size, dim in zip(self.categorical_cardinalities, self.embedding_dim): |
62 | 73 | if hvd.rank() == 0: |
63 | 74 | print(f'Creating embedding with size: {table_size} {dim}') |
64 | | - if self.use_mde_embeddings: |
65 | | - e = embedding.Embedding(input_dim=table_size, output_dim=dim, |
66 | | - combiner='sum', embeddings_initializer=initializer_cls()) |
67 | | - else: |
68 | | - e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim, |
69 | | - embeddings_initializer=initializer_cls()) |
| 75 | + e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim, |
| 76 | + embeddings_initializer=initializer_cls()) |
70 | 77 | self.embedding_layers.append(e) |
71 | 78 |
| 79 | + gpu_size = _gigabytes_to_elements(self.cpu_offloading_threshold_gb) |
72 | 80 | self.embedding = dmp.DistributedEmbedding(self.embedding_layers, |
73 | 81 | strategy='memory_balanced', |
74 | | - dp_input=False, |
75 | | - column_slice_threshold=self.column_slice_threshold) |
| 82 | + dp_input=self.data_parallel_input, |
| 83 | + column_slice_threshold=self.column_slice_threshold, |
| 84 | + row_slice_threshold=self.row_slice_threshold, |
| 85 | + data_parallel_threshold=self.data_parallel_threshold, |
| 86 | + gpu_embedding_size=gpu_size) |
76 | 87 |
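For readers unfamiliar with NVIDIA's distributed-embeddings API, the wiring above can be exercised in isolation. A minimal sketch, assuming an initialized Horovod job and a distributed-embeddings build that accepts the row-slice, data-parallel, and GPU-size keywords used by this change; the two toy tables and the 16 GB budget are invented, and the per-argument comments are a reading of the parameter names rather than documented semantics:

```python
import tensorflow as tf
import horovod.tensorflow as hvd
from distributed_embeddings.python.layers import dist_model_parallel as dmp

hvd.init()

# Two toy tables: one small enough to replicate, one large enough to shard.
layers = [tf.keras.layers.Embedding(input_dim=cardinality, output_dim=dim)
          for cardinality, dim in [(10_000, 16), (5_000_000, 64)]]

embedding = dmp.DistributedEmbedding(
    layers,
    strategy='memory_balanced',        # spread table memory evenly across ranks
    dp_input=True,                     # each rank feeds its own batch slice
    column_slice_threshold=None,       # split wide tables across ranks by column
    row_slice_threshold=None,          # split tall tables across ranks by row
    data_parallel_threshold=None,      # replicate small tables on every rank
    gpu_embedding_size=4_000_000_000,  # element budget (16 GB of fp32) before offload
)
```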
77 | 88 | def get_local_table_ids(self, rank): |
78 | | - if self.use_concat_embedding: |
| 89 | + if self.use_concat_embedding or self.data_parallel_input: |
79 | 90 | return list(range(self.num_all_categorical_features)) |
80 | 91 | else: |
81 | 92 | return self.embedding.strategy.input_ids_list[rank] |
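The branch above decides which tables a rank must restore from a checkpoint: replicated tables (concat embedding or data-parallel input) live on every rank, while model-parallel tables follow the placement computed by the sharding strategy. A self-contained mock of that logic with a hypothetical two-rank placement:

```python
class _FakeStrategy:
    # Hypothetical placement: rank 0 owns tables 0 and 2, rank 1 owns 1 and 3.
    input_ids_list = [[0, 2], [1, 3]]

def get_local_table_ids(rank, *, data_parallel_input, num_tables=4,
                        strategy=_FakeStrategy()):
    # Mirrors SparseModel.get_local_table_ids: replicated tables are local
    # everywhere; otherwise placement comes from the sharding strategy.
    if data_parallel_input:
        return list(range(num_tables))
    return strategy.input_ids_list[rank]

assert get_local_table_ids(0, data_parallel_input=True) == [0, 1, 2, 3]
assert get_local_table_ids(1, data_parallel_input=False) == [1, 3]
```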
@@ -127,4 +138,10 @@ def save_config(self, path): |
127 | 138 | def from_config(path): |
128 | 139 | with open(path) as f: |
129 | 140 | config = json.load(fp=f) |
| 141 | + if 'data_parallel_input' not in config: |
| 142 | + config['data_parallel_input'] = False |
| 143 | + if 'row_slice_threshold' not in config: |
| 144 | + config['row_slice_threshold'] = None |
| 145 | + if 'data_parallel_threshold' not in config: |
| 146 | + config['data_parallel_threshold'] = None |
130 | 147 | return SparseModel(**config) |
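The membership checks above keep configs saved before this change loadable: any JSON file missing the three new keys falls back to their defaults. A small round-trip sketch (the legacy config contents are invented):

```python
import io
import json

# A hypothetical config written by an older version of save_config:
# the three new keys are absent.
legacy = io.StringIO(json.dumps({'embedding_dim': [128, 128],
                                 'column_slice_threshold': None}))

config = json.load(legacy)
for key, default in [('data_parallel_input', False),
                     ('row_slice_threshold', None),
                     ('data_parallel_threshold', None)]:
    config.setdefault(key, default)

assert config['data_parallel_input'] is False
assert config['row_slice_threshold'] is None
assert config['data_parallel_threshold'] is None
```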