Skip to content

Commit 09f9fe6

Browse files
michal2409 and nv-kkudrynski
authored and committed
[nnUNet/TF2] Update container to 22.11, fix XLA+channel last conv, multi-gpu binding script
1 parent b1fc3c4 commit 09f9fe6

16 files changed

Lines changed: 321 additions & 480 deletions

File tree

TensorFlow2/Segmentation/nnUNet/Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:22.04-tf2-py3
1+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/tensorflow:22.11-tf2-py3
22
FROM ${FROM_IMAGE_NAME}
33

44
RUN pip install nvidia-pyindex
@@ -13,6 +13,7 @@ RUN unzip -qq awscliv2.zip
1313
RUN ./aws/install
1414
RUN rm -rf awscliv2.zip aws
1515

16+
ENV OMP_NUM_THREADS=2
1617
ENV TF_CPP_MIN_LOG_LEVEL 3
1718
ENV OMPI_MCA_coll_hcoll_enable 0
1819
ENV HCOLL_ENABLE_MCAST 0

TensorFlow2/Segmentation/nnUNet/README.md

Lines changed: 163 additions & 162 deletions
Large diffs are not rendered by default.

TensorFlow2/Segmentation/nnUNet/data_loading/dali_loader.py

Lines changed: 18 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import horovod.tensorflow as hvd
1818
import numpy as np
1919
import nvidia.dali.fn as fn
20-
import nvidia.dali.math as math
2120
import nvidia.dali.ops as ops
2221
import nvidia.dali.plugin.tf as dali_tf
2322
import nvidia.dali.types as types
@@ -57,7 +56,6 @@ def __init__(
5756
shuffle_input=True,
5857
input_x_files=None,
5958
input_y_files=None,
60-
use_cpu=False,
6159
):
6260
super().__init__(
6361
batch_size=batch_size,
@@ -85,19 +83,12 @@ def __init__(
8583

8684
self.dim = dim
8785
self.internal_seed = seed
88-
self.use_cpu = use_cpu
89-
90-
def mark_pipeline_start(self, x, y):
91-
if not self.use_cpu:
92-
x, y = x.gpu(), y.gpu()
93-
return x, y
9486

9587

9688
class TrainPipeline(GenericPipeline):
97-
def __init__(self, imgs, lbls, oversampling, patch_size, read_roi=False, batch_size_2d=None, **kwargs):
89+
def __init__(self, imgs, lbls, oversampling, patch_size, batch_size_2d=None, **kwargs):
9890
super().__init__(input_x_files=imgs, input_y_files=lbls, shuffle_input=True, **kwargs)
9991
self.oversampling = oversampling
100-
self.read_roi = read_roi
10192
self.patch_size = patch_size
10293
if self.dim == 2 and batch_size_2d is not None:
10394
self.patch_size = [batch_size_2d] + self.patch_size
@@ -129,7 +120,7 @@ def biased_crop_fn(self, img, lbl):
129120
roi_end=roi_end,
130121
crop_shape=[*self.patch_size, 1],
131122
)
132-
anchor = fn.slice(anchor, 0, 3, axes=[0]) # drop channel from anchor
123+
anchor = fn.slice(anchor, 0, 3, axes=[0])
133124
img, lbl = fn.slice(
134125
[img, lbl],
135126
anchor,
@@ -138,40 +129,7 @@ def biased_crop_fn(self, img, lbl):
138129
out_of_bounds_policy="pad",
139130
device="cpu",
140131
)
141-
142-
return img.gpu(), lbl.gpu()
143-
144-
def load_roi(self):
145-
lbl = self.input_y(name="ReaderY")
146-
lbl = fn.reshape(lbl, layout="DHWC")
147-
roi_start, roi_end = fn.segmentation.random_object_bbox(
148-
lbl,
149-
format="start_end",
150-
foreground_prob=self.oversampling,
151-
k_largest=2,
152-
device="cpu",
153-
cache_objects=True,
154-
)
155-
anchor = fn.roi_random_crop(lbl, roi_start=roi_start, roi_end=roi_end, crop_shape=[1, *self.patch_size])
156-
anchor = fn.slice(anchor, 1, 3, axes=[0]) # drop channel from anchor
157-
lbl = fn.slice(
158-
lbl,
159-
anchor,
160-
self.crop_shape,
161-
axis_names="DHW",
162-
out_of_bounds_policy="pad",
163-
device="cpu",
164-
)
165-
166-
img = self.input_x(
167-
name="ReaderX",
168-
roi_start=fn.cast(anchor, dtype=types.INT32),
169-
roi_axes=[1, 2, 3],
170-
roi_shape=self.patch_size,
171-
out_of_bounds_policy="pad",
172-
)
173-
img = fn.reshape(img, layout="DHWC")
174-
132+
img, lbl = img.gpu(), lbl.gpu()
175133
return img, lbl
176134

177135
def zoom_fn(self, img, lbl):
@@ -189,22 +147,18 @@ def zoom_fn(self, img, lbl):
189147
return img, lbl
190148

191149
def noise_fn(self, img):
192-
img_noised = img + fn.random.normal(img, stddev=fn.random.uniform(range=(0.0, 0.33)))
150+
img_noised = fn.noise.gaussian(img, stddev=fn.random.uniform(range=(0.0, 0.3)))
193151
return random_augmentation(0.15, img_noised, img)
194152

195153
def blur_fn(self, img):
196154
img_blurred = fn.gaussian_blur(img, sigma=fn.random.uniform(range=(0.5, 1.5)))
197155
return random_augmentation(0.15, img_blurred, img)
198156

199-
def brightness_fn(self, img):
200-
brightness_scale = random_augmentation(0.15, fn.random.uniform(range=(0.7, 1.3)), 1.0)
201-
return img * brightness_scale
202-
203-
def contrast_fn(self, img):
204-
min_, max_ = fn.reductions.min(img), fn.reductions.max(img)
205-
scale = random_augmentation(0.15, fn.random.uniform(range=(0.65, 1.5)), 1.0)
206-
img = math.clamp(img * scale, min_, max_)
207-
return img
157+
def brightness_contrast_fn(self, img):
158+
img_transformed = fn.brightness_contrast(
159+
img, brightness=fn.random.uniform(range=(0.7, 1.3)), contrast=fn.random.uniform(range=(0.65, 1.5))
160+
)
161+
return random_augmentation(0.15, img_transformed, img)
208162

209163
def flips_fn(self, img, lbl):
210164
kwargs = {
@@ -216,16 +170,13 @@ def flips_fn(self, img, lbl):
216170
return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)
217171

218172
def define_graph(self):
219-
if self.read_roi:
220-
img, lbl = self.load_roi()
221-
else:
222-
img, lbl = self.load_data()
223-
img, lbl = self.biased_crop_fn(img, lbl)
224-
img, lbl = img.gpu(), lbl.gpu()
173+
img, lbl = self.load_data()
174+
img, lbl = self.biased_crop_fn(img, lbl)
225175
img, lbl = self.zoom_fn(img, lbl)
226176
img, lbl = self.flips_fn(img, lbl)
227-
img = self.brightness_fn(img)
228-
img = self.contrast_fn(img)
177+
img = self.noise_fn(img)
178+
img = self.blur_fn(img)
179+
img = self.brightness_contrast_fn(img)
229180
return img, lbl
230181

231182

@@ -251,12 +202,11 @@ def define_graph(self):
251202

252203

253204
class BenchmarkPipeline(GenericPipeline):
254-
def __init__(self, imgs, lbls, patch_size, batch_size_2d=None, sw_benchmark=False, **kwargs):
205+
def __init__(self, imgs, lbls, patch_size, batch_size_2d=None, **kwargs):
255206
super().__init__(input_x_files=imgs, input_y_files=lbls, shuffle_input=False, **kwargs)
256207
self.patch_size = patch_size
257208
if self.dim == 2 and batch_size_2d is not None:
258209
self.patch_size = [batch_size_2d] + self.patch_size
259-
self.crop = not sw_benchmark
260210

261211
def crop_fn(self, img, lbl):
262212
img = fn.crop(img, crop=self.patch_size, out_of_bounds_policy="pad")
@@ -265,9 +215,8 @@ def crop_fn(self, img, lbl):
265215

266216
def define_graph(self):
267217
img, lbl = self.input_x(name="ReaderX").gpu(), self.input_y(name="ReaderY").gpu()
218+
img, lbl = self.crop_fn(img, lbl)
268219
img, lbl = fn.reshape(img, layout="DHWC"), fn.reshape(lbl, layout="DHWC")
269-
if self.crop:
270-
img, lbl = self.crop_fn(img, lbl)
271220
return img, lbl
272221

273222

@@ -293,7 +242,6 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
293242
"batch_size": batch_size,
294243
"num_threads": kwargs["num_workers"],
295244
"shard_id": device_id,
296-
"use_cpu": kwargs["use_cpu"],
297245
}
298246
if kwargs["dim"] == 2:
299247
if kwargs["benchmark"]:
@@ -308,13 +256,9 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
308256

309257
output_dtypes = (tf.float32, tf.uint8)
310258
if kwargs["benchmark"]:
311-
pipeline = BenchmarkPipeline(
312-
imgs, lbls, kwargs["patch_size"], sw_benchmark=kwargs["sw_benchmark"], **pipe_kwargs
313-
)
259+
pipeline = BenchmarkPipeline(imgs, lbls, kwargs["patch_size"], **pipe_kwargs)
314260
elif mode == "train":
315-
pipeline = TrainPipeline(
316-
imgs, lbls, kwargs["oversampling"], kwargs["patch_size"], kwargs["read_roi"], **pipe_kwargs
317-
)
261+
pipeline = TrainPipeline(imgs, lbls, kwargs["oversampling"], kwargs["patch_size"], **pipe_kwargs)
318262
elif mode == "eval":
319263
pipeline = EvalPipeline(imgs, lbls, kwargs["patch_size"], **pipe_kwargs)
320264
else:

TensorFlow2/Segmentation/nnUNet/data_loading/data_module.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,6 @@ def __init__(self, args):
4444
"nvol": self.args.nvol,
4545
"bench_steps": self.args.bench_steps,
4646
"meta": load_data(self.data_path, "*_meta.npy"),
47-
"read_roi": self.args.read_roi,
48-
"use_cpu": self.args.dali_use_cpu,
49-
"sw_benchmark": self.args.sw_benchmark,
5047
}
5148

5249
def setup(self, stage=None):

TensorFlow2/Segmentation/nnUNet/main.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import ctypes
16-
import os
17-
1815
from data_loading.data_module import DataModule
1916
from models.nn_unet import NNUnet
2017
from runtime.args import get_main_args
@@ -25,17 +22,6 @@
2522

2623

2724
def main(args):
28-
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
29-
os.environ["TF_GPU_THREAD_COUNT"] = "1"
30-
31-
_libcudart = ctypes.CDLL("libcudart.so")
32-
# Set device limit on the current device
33-
# cudaLimitMaxL2FetchGranularity = 0x05
34-
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
35-
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
36-
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
37-
assert pValue.contents.value == 128
38-
3925
hvd_init()
4026
if args.seed is not None:
4127
set_seed(args.seed)

TensorFlow2/Segmentation/nnUNet/models/layers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import nv_norms
1516
import tensorflow as tf
1617
import tensorflow_addons as tfa
1718

@@ -26,7 +27,7 @@
2627
class KaimingNormal(tf.keras.initializers.VarianceScaling):
2728
def __init__(self, negative_slope, seed=None):
2829
super().__init__(
29-
scale=2.0 / (1 + negative_slope ** 2), mode="fan_in", distribution="untruncated_normal", seed=seed
30+
scale=2.0 / (1 + negative_slope**2), mode="fan_in", distribution="untruncated_normal", seed=seed
3031
)
3132

3233
def get_config(self):
@@ -38,6 +39,8 @@ def get_norm(name):
3839
return tfa.layers.GroupNormalization(32, axis=-1, center=True, scale=True)
3940
elif "batch" in name:
4041
return tf.keras.layers.BatchNormalization(axis=-1, center=True, scale=True)
42+
elif "atex_instance" in name:
43+
return nv_norms.InstanceNormalization(axis=-1)
4144
elif "instance" in name:
4245
return tfa.layers.InstanceNormalization(axis=-1, center=True, scale=True)
4346
elif "none" in name:

TensorFlow2/Segmentation/nnUNet/models/nn_unet.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from runtime.utils import get_config_file, get_tta_flips, is_main_process
2020
from skimage.transform import resize
2121

22-
from models.sliding_window import sliding_window_inference
22+
from models.sliding_window import get_importance_kernel, sliding_window_inference
2323
from models.unet import UNet
2424

2525

@@ -41,6 +41,8 @@ def wrapped_model(inputs, *args, **kwargs):
4141

4242
self.model = wrapped_model
4343
else:
44+
if not self.args.xla and self.args.norm == "instance":
45+
self.args.norm = "atex_instance"
4446
self.model = UNet(
4547
input_shape=input_shape,
4648
n_class=n_class,
@@ -54,11 +56,28 @@ def wrapped_model(inputs, *args, **kwargs):
5456
if is_main_process():
5557
print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
5658
self.tta_flips = get_tta_flips(self.args.dim)
59+
if self.args.dim == 3:
60+
self.predictor = self.sw_inference
61+
elif self.args.benchmark:
62+
self.predictor = self.call
63+
else:
64+
self.predictor = self.call_2d
65+
66+
if args.dim == 3:
67+
importance_kernel = get_importance_kernel(self.patch_size, args.blend_mode, 0.125)
68+
self.importance_map = tf.tile(
69+
tf.reshape(importance_kernel, shape=[1, *self.patch_size, 1]),
70+
multiples=[1, 1, 1, 1, n_class],
71+
)
5772

58-
@tf.function(experimental_relax_shapes=True)
73+
@tf.function
5974
def call(self, *args, **kwargs):
6075
return self.model(*args, **kwargs)
6176

77+
@tf.function(reduce_retracing=True)
78+
def call_2d(self, *args, **kwargs):
79+
return self.model(*args, **kwargs)
80+
6281
@tf.function
6382
def compute_loss(self, loss_fn, label, preds):
6483
if self.args.deep_supervision:
@@ -77,21 +96,19 @@ def sw_inference(self, img, **kwargs):
7796
return sliding_window_inference(
7897
inputs=img,
7998
roi_size=self.patch_size,
80-
sw_batch_size=self.args.sw_batch_size,
8199
model=self.model,
82100
overlap=self.args.overlap,
83101
n_class=self.n_class,
84-
blend_mode=self.args.blend_mode,
102+
importance_map=self.importance_map,
85103
**kwargs,
86104
)
87105

88106
def inference(self, img):
89-
predictor = self.call if self.args.dim == 2 else self.sw_inference
90-
pred = predictor(img, training=False)
107+
pred = self.predictor(img, training=False)
91108
if self.args.tta:
92109
for flip_axes in self.tta_flips:
93110
flipped_img = tf.reverse(img, axis=flip_axes)
94-
flipped_pred = predictor(flipped_img, training=False)
111+
flipped_pred = self.predictor(flipped_img, training=False)
95112
pred = pred + tf.reverse(flipped_pred, axis=flip_axes)
96113
pred = pred / (len(self.tta_flips) + 1)
97114
return pred

0 commit comments

Comments (0)