1717import horovod .tensorflow as hvd
1818import numpy as np
1919import nvidia .dali .fn as fn
20- import nvidia .dali .math as math
2120import nvidia .dali .ops as ops
2221import nvidia .dali .plugin .tf as dali_tf
2322import nvidia .dali .types as types
@@ -57,7 +56,6 @@ def __init__(
5756 shuffle_input = True ,
5857 input_x_files = None ,
5958 input_y_files = None ,
60- use_cpu = False ,
6159 ):
6260 super ().__init__ (
6361 batch_size = batch_size ,
@@ -85,19 +83,12 @@ def __init__(
8583
8684 self .dim = dim
8785 self .internal_seed = seed
88- self .use_cpu = use_cpu
89-
90- def mark_pipeline_start (self , x , y ):
91- if not self .use_cpu :
92- x , y = x .gpu (), y .gpu ()
93- return x , y
9486
9587
9688class TrainPipeline (GenericPipeline ):
97- def __init__ (self , imgs , lbls , oversampling , patch_size , read_roi = False , batch_size_2d = None , ** kwargs ):
89+ def __init__ (self , imgs , lbls , oversampling , patch_size , batch_size_2d = None , ** kwargs ):
9890 super ().__init__ (input_x_files = imgs , input_y_files = lbls , shuffle_input = True , ** kwargs )
9991 self .oversampling = oversampling
100- self .read_roi = read_roi
10192 self .patch_size = patch_size
10293 if self .dim == 2 and batch_size_2d is not None :
10394 self .patch_size = [batch_size_2d ] + self .patch_size
@@ -129,7 +120,7 @@ def biased_crop_fn(self, img, lbl):
129120 roi_end = roi_end ,
130121 crop_shape = [* self .patch_size , 1 ],
131122 )
132- anchor = fn .slice (anchor , 0 , 3 , axes = [0 ]) # drop channel from anchor
123+ anchor = fn .slice (anchor , 0 , 3 , axes = [0 ])
133124 img , lbl = fn .slice (
134125 [img , lbl ],
135126 anchor ,
@@ -138,40 +129,7 @@ def biased_crop_fn(self, img, lbl):
138129 out_of_bounds_policy = "pad" ,
139130 device = "cpu" ,
140131 )
141-
142- return img .gpu (), lbl .gpu ()
143-
144- def load_roi (self ):
145- lbl = self .input_y (name = "ReaderY" )
146- lbl = fn .reshape (lbl , layout = "DHWC" )
147- roi_start , roi_end = fn .segmentation .random_object_bbox (
148- lbl ,
149- format = "start_end" ,
150- foreground_prob = self .oversampling ,
151- k_largest = 2 ,
152- device = "cpu" ,
153- cache_objects = True ,
154- )
155- anchor = fn .roi_random_crop (lbl , roi_start = roi_start , roi_end = roi_end , crop_shape = [1 , * self .patch_size ])
156- anchor = fn .slice (anchor , 1 , 3 , axes = [0 ]) # drop channel from anchor
157- lbl = fn .slice (
158- lbl ,
159- anchor ,
160- self .crop_shape ,
161- axis_names = "DHW" ,
162- out_of_bounds_policy = "pad" ,
163- device = "cpu" ,
164- )
165-
166- img = self .input_x (
167- name = "ReaderX" ,
168- roi_start = fn .cast (anchor , dtype = types .INT32 ),
169- roi_axes = [1 , 2 , 3 ],
170- roi_shape = self .patch_size ,
171- out_of_bounds_policy = "pad" ,
172- )
173- img = fn .reshape (img , layout = "DHWC" )
174-
132+ img , lbl = img .gpu (), lbl .gpu ()
175133 return img , lbl
176134
177135 def zoom_fn (self , img , lbl ):
@@ -189,22 +147,18 @@ def zoom_fn(self, img, lbl):
189147 return img , lbl
190148
191149 def noise_fn (self , img ):
192- img_noised = img + fn .random . normal (img , stddev = fn .random .uniform (range = (0.0 , 0.33 )))
150+ img_noised = fn .noise . gaussian (img , stddev = fn .random .uniform (range = (0.0 , 0.3 )))
193151 return random_augmentation (0.15 , img_noised , img )
194152
195153 def blur_fn (self , img ):
196154 img_blurred = fn .gaussian_blur (img , sigma = fn .random .uniform (range = (0.5 , 1.5 )))
197155 return random_augmentation (0.15 , img_blurred , img )
198156
199- def brightness_fn (self , img ):
200- brightness_scale = random_augmentation (0.15 , fn .random .uniform (range = (0.7 , 1.3 )), 1.0 )
201- return img * brightness_scale
202-
203- def contrast_fn (self , img ):
204- min_ , max_ = fn .reductions .min (img ), fn .reductions .max (img )
205- scale = random_augmentation (0.15 , fn .random .uniform (range = (0.65 , 1.5 )), 1.0 )
206- img = math .clamp (img * scale , min_ , max_ )
207- return img
157+ def brightness_contrast_fn (self , img ):
158+ img_transformed = fn .brightness_contrast (
159+ img , brightness = fn .random .uniform (range = (0.7 , 1.3 )), contrast = fn .random .uniform (range = (0.65 , 1.5 ))
160+ )
161+ return random_augmentation (0.15 , img_transformed , img )
208162
209163 def flips_fn (self , img , lbl ):
210164 kwargs = {
@@ -216,16 +170,13 @@ def flips_fn(self, img, lbl):
216170 return fn .flip (img , ** kwargs ), fn .flip (lbl , ** kwargs )
217171
218172 def define_graph (self ):
219- if self .read_roi :
220- img , lbl = self .load_roi ()
221- else :
222- img , lbl = self .load_data ()
223- img , lbl = self .biased_crop_fn (img , lbl )
224- img , lbl = img .gpu (), lbl .gpu ()
173+ img , lbl = self .load_data ()
174+ img , lbl = self .biased_crop_fn (img , lbl )
225175 img , lbl = self .zoom_fn (img , lbl )
226176 img , lbl = self .flips_fn (img , lbl )
227- img = self .brightness_fn (img )
228- img = self .contrast_fn (img )
177+ img = self .noise_fn (img )
178+ img = self .blur_fn (img )
179+ img = self .brightness_contrast_fn (img )
229180 return img , lbl
230181
231182
@@ -251,12 +202,11 @@ def define_graph(self):
251202
252203
253204class BenchmarkPipeline (GenericPipeline ):
254- def __init__ (self , imgs , lbls , patch_size , batch_size_2d = None , sw_benchmark = False , ** kwargs ):
205+ def __init__ (self , imgs , lbls , patch_size , batch_size_2d = None , ** kwargs ):
255206 super ().__init__ (input_x_files = imgs , input_y_files = lbls , shuffle_input = False , ** kwargs )
256207 self .patch_size = patch_size
257208 if self .dim == 2 and batch_size_2d is not None :
258209 self .patch_size = [batch_size_2d ] + self .patch_size
259- self .crop = not sw_benchmark
260210
261211 def crop_fn (self , img , lbl ):
262212 img = fn .crop (img , crop = self .patch_size , out_of_bounds_policy = "pad" )
@@ -265,9 +215,8 @@ def crop_fn(self, img, lbl):
265215
266216 def define_graph (self ):
267217 img , lbl = self .input_x (name = "ReaderX" ).gpu (), self .input_y (name = "ReaderY" ).gpu ()
218+ img , lbl = self .crop_fn (img , lbl )
268219 img , lbl = fn .reshape (img , layout = "DHWC" ), fn .reshape (lbl , layout = "DHWC" )
269- if self .crop :
270- img , lbl = self .crop_fn (img , lbl )
271220 return img , lbl
272221
273222
@@ -293,7 +242,6 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
293242 "batch_size" : batch_size ,
294243 "num_threads" : kwargs ["num_workers" ],
295244 "shard_id" : device_id ,
296- "use_cpu" : kwargs ["use_cpu" ],
297245 }
298246 if kwargs ["dim" ] == 2 :
299247 if kwargs ["benchmark" ]:
@@ -308,13 +256,9 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
308256
309257 output_dtypes = (tf .float32 , tf .uint8 )
310258 if kwargs ["benchmark" ]:
311- pipeline = BenchmarkPipeline (
312- imgs , lbls , kwargs ["patch_size" ], sw_benchmark = kwargs ["sw_benchmark" ], ** pipe_kwargs
313- )
259+ pipeline = BenchmarkPipeline (imgs , lbls , kwargs ["patch_size" ], ** pipe_kwargs )
314260 elif mode == "train" :
315- pipeline = TrainPipeline (
316- imgs , lbls , kwargs ["oversampling" ], kwargs ["patch_size" ], kwargs ["read_roi" ], ** pipe_kwargs
317- )
261+ pipeline = TrainPipeline (imgs , lbls , kwargs ["oversampling" ], kwargs ["patch_size" ], ** pipe_kwargs )
318262 elif mode == "eval" :
319263 pipeline = EvalPipeline (imgs , lbls , kwargs ["patch_size" ], ** pipe_kwargs )
320264 else :
0 commit comments