Skip to content

Commit 4da7bf2

Browse files
michal2409 authored and nv-kkudrynski committed
[nnUNet/PyT] Update container to 22.11, channel last conv, nvFuser InstanceNorm, multi-gpu binding script
1 parent bf00fe1 commit 4da7bf2

14 files changed

Lines changed: 545 additions & 267 deletions

File tree

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.11-py3
1+
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.11-py3
22
FROM ${FROM_IMAGE_NAME}
33

44
ADD ./requirements.txt .
55
RUN pip install --disable-pip-version-check -r requirements.txt
6-
RUN pip install monai==0.8.1 --no-dependencies
7-
RUN pip uninstall -y torchtext
6+
RUN pip install monai==1.0.0 --no-dependencies
87
RUN pip install numpy --upgrade
9-
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==1.16.0
108

119
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
1210
RUN unzip -qq awscliv2.zip
1311
RUN ./aws/install
1412
RUN rm -rf awscliv2.zip aws
1513

14+
ENV OMP_NUM_THREADS=2
1615
WORKDIR /workspace/nnunet_pyt
1716
ADD . /workspace/nnunet_pyt
17+
RUN cp utils/instance_norm.py /usr/local/lib/python3.8/dist-packages/apex/normalization

PyTorch/Segmentation/nnUNet/README.md

Lines changed: 170 additions & 145 deletions
Large diffs are not rendered by default.

PyTorch/Segmentation/nnUNet/data_loading/dali_loader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ def __init__(self, batch_size, num_threads, device_id, **kwargs):
3636
self.kwargs = kwargs
3737
self.dim = kwargs["dim"]
3838
self.device = device_id
39+
self.layout = kwargs["layout"]
3940
self.patch_size = kwargs["patch_size"]
4041
self.load_to_gpu = kwargs["load_to_gpu"]
4142
self.input_x = self.get_reader(kwargs["imgs"])
4243
self.input_y = self.get_reader(kwargs["lbls"]) if kwargs["lbls"] is not None else None
44+
self.cdhw2dhwc = ops.Transpose(device="gpu", perm=[1, 2, 3, 0])
4345

4446
def get_reader(self, data):
4547
return ops.readers.Numpy(
@@ -67,6 +69,10 @@ def load_data(self):
6769
return img, lbl
6870
return img
6971

72+
def make_dhwc_layout(self, img, lbl):
73+
img, lbl = self.cdhw2dhwc(img), self.cdhw2dhwc(lbl)
74+
return img, lbl
75+
7076
def crop(self, data):
7177
return fn.crop(data, crop=self.patch_size, out_of_bounds_policy="pad")
7278

@@ -154,6 +160,8 @@ def define_graph(self):
154160
img = self.contrast_fn(img)
155161
if self.dim == 2:
156162
img, lbl = self.transpose_fn(img, lbl)
163+
if self.layout == "NDHWC" and self.dim == 3:
164+
img, lbl = self.make_dhwc_layout(img, lbl)
157165
return img, lbl
158166

159167

@@ -171,6 +179,8 @@ def define_graph(self):
171179
meta = self.input_meta(name="ReaderM")
172180
orig_lbl = self.input_orig_y(name="ReaderO")
173181
return img, lbl, meta, orig_lbl
182+
if self.layout == "NDHWC" and self.dim == 3:
183+
img, lbl = self.make_dhwc_layout(img, lbl)
174184
return img, lbl
175185

176186

@@ -204,6 +214,8 @@ def define_graph(self):
204214
img, lbl = self.crop_fn(img, lbl)
205215
if self.dim == 2:
206216
img, lbl = self.transpose_fn(img, lbl)
217+
if self.layout == "NDHWC" and self.dim == 3:
218+
img, lbl = self.make_dhwc_layout(img, lbl)
207219
return img, lbl
208220

209221

@@ -250,6 +262,10 @@ def fetch_dali_loader(imgs, lbls, batch_size, mode, **kwargs):
250262
pipe_kwargs.update({"patch_size": [batch_size_2d] + kwargs["patch_size"]})
251263

252264
rank = int(os.getenv("LOCAL_RANK", "0"))
265+
if mode == "eval": # We sharded the data for evaluation manually.
266+
rank = 0
267+
pipe_kwargs["gpus"] = 1
268+
253269
pipe = pipeline(batch_size, kwargs["num_workers"], rank, **pipe_kwargs)
254270
return LightningWrapper(
255271
pipe,

PyTorch/Segmentation/nnUNet/data_loading/data_module.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(self, args):
3434
"seed": self.args.seed,
3535
"gpus": self.args.gpus,
3636
"nvol": self.args.nvol,
37+
"layout": self.args.layout,
3738
"overlap": self.args.overlap,
3839
"benchmark": self.args.benchmark,
3940
"num_workers": self.args.num_workers,
@@ -57,6 +58,11 @@ def setup(self, stage=None):
5758
self.kwargs.update({"orig_lbl": orig_lbl, "meta": meta})
5859
self.train_imgs, self.train_lbls = get_split(imgs, train_idx), get_split(lbls, train_idx)
5960
self.val_imgs, self.val_lbls = get_split(imgs, val_idx), get_split(lbls, val_idx)
61+
62+
if self.args.gpus > 1:
63+
rank = int(os.getenv("LOCAL_RANK", "0"))
64+
self.val_imgs = self.val_imgs[rank :: self.args.gpus]
65+
self.val_lbls = self.val_lbls[rank :: self.args.gpus]
6066
else:
6167
self.kwargs.update({"meta": test_meta})
6268
print0(f"{len(self.train_imgs)} training, {len(self.val_imgs)} validation, {len(self.test_imgs)} test examples")

PyTorch/Segmentation/nnUNet/main.py

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,64 @@
1414

1515
import os
1616

17+
import torch
1718
from pytorch_lightning import Trainer, seed_everything
1819
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, RichProgressBar
19-
from pytorch_lightning.loggers import TensorBoardLogger
20+
from pytorch_lightning.plugins.io import AsyncCheckpointIO
21+
from pytorch_lightning.strategies import DDPStrategy
2022

2123
from data_loading.data_module import DataModule
2224
from nnunet.nn_unet import NNUnet
2325
from utils.args import get_main_args
2426
from utils.logger import LoggingCallback
2527
from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path
2628

27-
if __name__ == "__main__":
29+
torch.backends.cuda.matmul.allow_tf32 = True
30+
torch.backends.cudnn.allow_tf32 = True
31+
32+
33+
def get_trainer(args, callbacks):
34+
return Trainer(
35+
logger=False,
36+
default_root_dir=args.results,
37+
benchmark=True,
38+
deterministic=False,
39+
max_epochs=args.epochs,
40+
precision=16 if args.amp else 32,
41+
gradient_clip_val=args.gradient_clip_val,
42+
enable_checkpointing=args.save_ckpt,
43+
callbacks=callbacks,
44+
num_sanity_val_steps=0,
45+
accelerator="gpu",
46+
devices=args.gpus,
47+
num_nodes=args.nodes,
48+
plugins=[AsyncCheckpointIO()],
49+
strategy=DDPStrategy(
50+
find_unused_parameters=False,
51+
static_graph=True,
52+
gradient_as_bucket_view=True,
53+
),
54+
limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
55+
limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
56+
limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
57+
)
58+
59+
60+
def main():
2861
args = get_main_args()
29-
set_granularity() # Increase maximum fetch granularity of L2 to 128 bytes
62+
set_granularity()
3063
set_cuda_devices(args)
3164
if args.seed is not None:
3265
seed_everything(args.seed)
3366
data_module = DataModule(args)
3467
data_module.setup()
3568
ckpt_path = verify_ckpt_path(args)
3669

37-
model = NNUnet(args)
70+
if ckpt_path is not None:
71+
model = NNUnet.load_from_checkpoint(ckpt_path, strict=False, args=args)
72+
else:
73+
model = NNUnet(args)
3874
callbacks = [RichProgressBar(), ModelSummary(max_depth=2)]
39-
logger = False
4075
if args.benchmark:
4176
batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
4277
filnename = args.logname if args.logname is not None else "perf.json"
@@ -51,13 +86,6 @@
5186
)
5287
)
5388
elif args.exec_mode == "train":
54-
if args.tb_logs:
55-
logger = TensorBoardLogger(
56-
save_dir=f"{args.results}/tb_logs",
57-
name=f"task={args.task}_dim={args.dim}_fold={args.fold}_precision={16 if args.amp else 32}",
58-
default_hp_metric=False,
59-
version=0,
60-
)
6189
if args.save_ckpt:
6290
callbacks.append(
6391
ModelCheckpoint(
@@ -69,26 +97,7 @@
6997
)
7098
)
7199

72-
trainer = Trainer(
73-
logger=logger,
74-
default_root_dir=args.results,
75-
benchmark=True,
76-
deterministic=False,
77-
max_epochs=args.epochs,
78-
precision=16 if args.amp else 32,
79-
gradient_clip_val=args.gradient_clip_val,
80-
enable_checkpointing=args.save_ckpt,
81-
callbacks=callbacks,
82-
num_sanity_val_steps=0,
83-
accelerator="gpu",
84-
devices=args.gpus,
85-
num_nodes=args.nodes,
86-
strategy="ddp" if args.gpus > 1 else None,
87-
limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
88-
limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
89-
limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
90-
)
91-
100+
trainer = get_trainer(args, callbacks)
92101
if args.benchmark:
93102
if args.exec_mode == "train":
94103
trainer.fit(model, train_dataloaders=data_module.train_dataloader())
@@ -99,7 +108,7 @@
99108
model.start_benchmark = 1
100109
trainer.test(model, dataloaders=data_module.test_dataloader(), verbose=False)
101110
elif args.exec_mode == "train":
102-
trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
111+
trainer.fit(model, datamodule=data_module)
103112
elif args.exec_mode == "evaluate":
104113
trainer.validate(model, dataloaders=data_module.val_dataloader())
105114
elif args.exec_mode == "predict":
@@ -113,4 +122,8 @@
113122
model.save_dir = save_dir
114123
make_empty_dir(save_dir)
115124
model.args = args
116-
trainer.test(model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path)
125+
trainer.test(model, dataloaders=data_module.test_dataloader())
126+
127+
128+
if __name__ == "__main__":
129+
main()

PyTorch/Segmentation/nnUNet/nnunet/metrics.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,45 +13,63 @@
1313
# limitations under the License.
1414

1515
import torch
16-
from monai.metrics import compute_meandice, do_metric_reduction
17-
from monai.networks.utils import one_hot
1816
from torchmetrics import Metric
1917

2018

2119
class Dice(Metric):
20+
full_state_update = False
21+
2222
def __init__(self, n_class, brats):
2323
super().__init__(dist_sync_on_step=False)
2424
self.n_class = n_class
2525
self.brats = brats
26-
self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
2726
self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
2827
self.add_state("dice", default=torch.zeros((n_class,)), dist_reduce_fx="sum")
28+
self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
2929

3030
def update(self, p, y, l):
31-
if self.brats:
32-
p = (torch.sigmoid(p) > 0.5).int()
33-
y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
34-
y = torch.stack([y_wt, y_tc, y_et], dim=1)
35-
else:
36-
p, y = self.ohe(torch.argmax(p, dim=1)), self.ohe(y)
37-
3831
self.steps += 1
32+
self.dice += self.compute_stats_brats(p, y) if self.brats else self.compute_stats(p, y)
3933
self.loss += l
40-
self.dice += self.compute_metric(p, y, compute_meandice, 1, 0)
4134

4235
def compute(self):
4336
return 100 * self.dice / self.steps, self.loss / self.steps
4437

45-
def ohe(self, x):
46-
return one_hot(x.unsqueeze(1), num_classes=self.n_class + 1, dim=1)
47-
48-
def compute_metric(self, p, y, metric_fn, best_metric, worst_metric):
49-
metric = metric_fn(p, y, include_background=self.brats)
50-
metric = torch.nan_to_num(metric, nan=worst_metric, posinf=worst_metric, neginf=worst_metric)
51-
metric = do_metric_reduction(metric, "mean_batch")[0]
38+
def compute_stats_brats(self, p, y):
39+
scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
40+
p = (torch.sigmoid(p) > 0.5).int()
41+
y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
42+
y = torch.stack([y_wt, y_tc, y_et], dim=1)
5243

5344
for i in range(self.n_class):
54-
if (y[:, i] != 1).all():
55-
metric[i - 1] += best_metric if (p[:, i] != 1).all() else worst_metric
45+
p_i, y_i = p[:, i], y[:, i]
46+
if (y_i != 1).all():
47+
# no foreground class
48+
scores[i - 1] += 1 if (p_i != 1).all() else 0
49+
continue
50+
tp, fn, fp = self.get_stats(p_i, y_i, 1)
51+
denom = (2 * tp + fp + fn).to(torch.float)
52+
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
53+
scores[i - 1] += score_cls
54+
return scores
55+
56+
def compute_stats(self, p, y):
57+
scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
58+
p = torch.argmax(p, dim=1)
59+
for i in range(1, self.n_class + 1):
60+
if (y != i).all():
61+
# no foreground class
62+
scores[i - 1] += 1 if (p != i).all() else 0
63+
continue
64+
tp, fn, fp = self.get_stats(p, y, i)
65+
denom = (2 * tp + fp + fn).to(torch.float)
66+
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
67+
scores[i - 1] += score_cls
68+
return scores
5669

57-
return metric
70+
@staticmethod
71+
def get_stats(p, y, c):
72+
tp = torch.logical_and(p == c, y == c).sum()
73+
fn = torch.logical_and(p != c, y == c).sum()
74+
fp = torch.logical_and(p == c, y != c).sum()
75+
return tp, fn, fp

0 commit comments

Comments
 (0)