|
32 | 32 | "import tensorflow_hub as hub\n", |
33 | 33 | "from datetime import datetime\n", |
34 | 34 | "import requests\n", |
35 | | - "import tf_keras\n", |
36 | 35 | "print(\"We are using Tensorflow version: \", tf.__version__)" |
37 | 36 | ] |
38 | 37 | }, |
|
99 | 98 | "outputs": [], |
100 | 99 | "source": [ |
101 | 100 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", |
102 | | - "data_root = tf_keras.utils.get_file(\n", |
| 101 | + "data_root = tf.keras.utils.get_file(\n", |
103 | 102 | " 'flower_photos',\n", |
104 | 103 | " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n", |
105 | 104 | " untar=True)\n", |
|
108 | 107 | "img_height = 224\n", |
109 | 108 | "img_width = 224\n", |
110 | 109 | "\n", |
111 | | - "train_ds = tf_keras.utils.image_dataset_from_directory(\n", |
| 110 | + "train_ds = tf.keras.utils.image_dataset_from_directory(\n", |
112 | 111 | " str(data_root),\n", |
113 | 112 | " validation_split=0.2,\n", |
114 | 113 | " subset=\"training\",\n", |
|
117 | 116 | " batch_size=batch_size\n", |
118 | 117 | ")\n", |
119 | 118 | "\n", |
120 | | - "val_ds = tf_keras.utils.image_dataset_from_directory(\n", |
| 119 | + "val_ds = tf.keras.utils.image_dataset_from_directory(\n", |
121 | 120 | " str(data_root),\n", |
122 | 121 | " validation_split=0.2,\n", |
123 | 122 | " subset=\"validation\",\n", |
|
147 | 146 | "metadata": {}, |
148 | 147 | "outputs": [], |
149 | 148 | "source": [ |
150 | | - "normalization_layer = tf_keras.layers.Rescaling(1./255)\n", |
| 149 | + "normalization_layer = tf.keras.layers.Rescaling(1./255)\n", |
151 | 150 | "train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.\n", |
152 | 151 | "val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.\n", |
153 | 152 | "\n", |
|
221 | 220 | "id": "70b3eb9b", |
222 | 221 | "metadata": {}, |
223 | 222 | "source": [ |
224 | | - "Attach the last fully connected classification layer in a **tf_keras.Sequential** model." |
| 223 | + "Attach the last fully connected classification layer in a **tf.keras.Sequential** model." |
225 | 224 | ] |
226 | 225 | }, |
227 | 226 | { |
|
233 | 232 | "source": [ |
234 | 233 | "num_classes = len(class_names)\n", |
235 | 234 | "\n", |
236 | | - "fp32_model = tf_keras.Sequential([\n", |
| 235 | + "fp32_model = tf.keras.Sequential([\n", |
237 | 236 | " feature_extractor_layer,\n", |
238 | | - " tf_keras.layers.Dense(num_classes)\n", |
| 237 | + " tf.keras.layers.Dense(num_classes)\n", |
239 | 238 | "])\n", |
240 | 239 | "\n", |
241 | 240 | "if arch == 'SPR':\n", |
242 | 241 | " # Create a deep copy of the model to train the bf16 model separately to compare accuracy\n", |
243 | | - " bf16_model = tf_keras.models.clone_model(fp32_model)\n", |
| 242 | + " bf16_model = tf.keras.models.clone_model(fp32_model)\n", |
244 | 243 | "\n", |
245 | 244 | "fp32_model.summary()" |
246 | 245 | ] |
|
260 | 259 | "metadata": {}, |
261 | 260 | "outputs": [], |
262 | 261 | "source": [ |
263 | | - "class TimeHistory(tf_keras.callbacks.Callback):\n", |
| 262 | + "class TimeHistory(tf.keras.callbacks.Callback):\n", |
264 | 263 | " def on_train_begin(self, logs={}):\n", |
265 | 264 | " self.times = []\n", |
266 | 265 | " self.throughput = []\n", |
|
290 | 289 | "outputs": [], |
291 | 290 | "source": [ |
292 | 291 | "fp32_model.compile(\n", |
293 | | - " optimizer=tf_keras.optimizers.SGD(),\n", |
294 | | - " loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", |
| 292 | + " optimizer=tf.keras.optimizers.SGD(),\n", |
| 293 | + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", |
295 | 294 | " metrics=['acc'])" |
296 | 295 | ] |
297 | 296 | }, |
|
374 | 373 | "if arch == 'SPR':\n", |
375 | 374 | " # Compile\n", |
376 | 375 | " bf16_model.compile(\n", |
377 | | - " optimizer=tf_keras.optimizers.SGD(),\n", |
378 | | - " loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", |
| 376 | + " optimizer=tf.keras.optimizers.SGD(),\n", |
| 377 | + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", |
379 | 378 | " metrics=['acc'])\n", |
380 | 379 | " \n", |
381 | 380 | " # Train\n", |
|
0 commit comments