Skip to content

Commit 60c7d97

Browse files
michal2409nv-kkudrynski
authored andcommitted
[nnUNet/PyT] Add BraTS22 notebook
1 parent a2f02eb commit 60c7d97

6 files changed

Lines changed: 970 additions & 17 deletions

File tree

60.1 KB
Loading

PyTorch/Segmentation/nnUNet/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
elif args.exec_mode == "train":
102102
trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
103103
elif args.exec_mode == "evaluate":
104-
trainer.validate(model, val_dataloaders=data_module.val_dataloader())
104+
trainer.validate(model, dataloaders=data_module.val_dataloader())
105105
elif args.exec_mode == "predict":
106106
if args.save_preds:
107107
ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
@@ -113,4 +113,4 @@
113113
model.save_dir = save_dir
114114
make_empty_dir(save_dir)
115115
model.args = args
116-
trainer.test(model, test_dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path)
116+
trainer.test(model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path)
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import torch
17+
import torch.nn as nn
18+
19+
normalizations = {
20+
"instancenorm3d": nn.InstanceNorm3d,
21+
"instancenorm2d": nn.InstanceNorm2d,
22+
"batchnorm3d": nn.BatchNorm3d,
23+
"batchnorm2d": nn.BatchNorm2d,
24+
}
25+
26+
convolutions = {
27+
"Conv2d": nn.Conv2d,
28+
"Conv3d": nn.Conv3d,
29+
"ConvTranspose2d": nn.ConvTranspose2d,
30+
"ConvTranspose3d": nn.ConvTranspose3d,
31+
}
32+
33+
34+
def get_norm(name, out_channels, groups=32):
35+
if "groupnorm" in name:
36+
return nn.GroupNorm(groups, out_channels, affine=True)
37+
return normalizations[name](out_channels, affine=True)
38+
39+
40+
def get_conv(in_channels, out_channels, kernel_size, stride, dim=3, bias=False):
41+
conv = convolutions[f"Conv{dim}d"]
42+
padding = get_padding(kernel_size, stride)
43+
return conv(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
44+
45+
46+
def get_transp_conv(in_channels, out_channels, kernel_size, stride, dim):
47+
conv = convolutions[f"ConvTranspose{dim}d"]
48+
padding = get_padding(kernel_size, stride)
49+
output_padding = get_output_padding(kernel_size, stride, padding)
50+
return conv(in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=True)
51+
52+
53+
def get_padding(kernel_size, stride):
54+
kernel_size_np = np.atleast_1d(kernel_size)
55+
stride_np = np.atleast_1d(stride)
56+
padding_np = (kernel_size_np - stride_np + 1) / 2
57+
padding = tuple(int(p) for p in padding_np)
58+
return padding if len(padding) > 1 else padding[0]
59+
60+
61+
def get_output_padding(kernel_size, stride, padding):
62+
kernel_size_np = np.atleast_1d(kernel_size)
63+
stride_np = np.atleast_1d(stride)
64+
padding_np = np.atleast_1d(padding)
65+
out_padding_np = 2 * padding_np + stride_np - kernel_size_np
66+
out_padding = tuple(int(p) for p in out_padding_np)
67+
return out_padding if len(out_padding) > 1 else out_padding[0]
68+
69+
70+
class InputBlock(nn.Module):
71+
def __init__(self, in_channels, out_channels, **kwargs):
72+
super(InputBlock, self).__init__()
73+
self.conv1 = get_conv(in_channels, out_channels, 3, 1)
74+
self.conv2 = get_conv(out_channels, out_channels, 3, 1)
75+
self.norm = get_norm(kwargs["norm"], out_channels)
76+
self.relu = nn.ReLU(inplace=True)
77+
78+
def forward(self, x):
79+
x = self.conv1(x)
80+
x = self.norm(x)
81+
x = self.relu(x)
82+
x = self.conv2(x)
83+
x = self.relu(x)
84+
return x
85+
86+
87+
class ConvLayer(nn.Module):
88+
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
89+
super(ConvLayer, self).__init__()
90+
self.conv = get_conv(in_channels, out_channels, kernel_size, stride)
91+
self.norm = get_norm(kwargs["norm"], in_channels)
92+
self.relu = nn.ReLU(inplace=True)
93+
94+
def forward(self, x):
95+
x = self.norm(x)
96+
x = self.conv(x)
97+
x = self.relu(x)
98+
return x
99+
100+
101+
class ConvBlock(nn.Module):
102+
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
103+
super(ConvBlock, self).__init__()
104+
self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, stride, **kwargs)
105+
self.conv2 = ConvLayer(out_channels, out_channels, kernel_size, 1, **kwargs)
106+
107+
def forward(self, x):
108+
x = self.conv1(x)
109+
x = self.conv2(x)
110+
return x
111+
112+
113+
class UpsampleBlock(nn.Module):
114+
def __init__(self, in_channels, out_channels, kernel_size, stride, **kwargs):
115+
super(UpsampleBlock, self).__init__()
116+
self.conv_block = ConvBlock(out_channels + in_channels, out_channels, kernel_size, 1, **kwargs)
117+
118+
def forward(self, x, x_skip):
119+
x = nn.functional.interpolate(x, scale_factor=2, mode="trilinear", align_corners=True)
120+
x = torch.cat((x, x_skip), dim=1)
121+
x = self.conv_block(x)
122+
return x
123+
124+
125+
class OutputBlock(nn.Module):
126+
def __init__(self, in_channels, out_channels, dim):
127+
super(OutputBlock, self).__init__()
128+
self.conv = get_conv(in_channels, out_channels, kernel_size=1, stride=1, dim=dim, bias=True)
129+
130+
def forward(self, input_data):
131+
return self.conv(input_data)
132+
133+
134+
class UNet3D(nn.Module):
135+
def __init__(
136+
self,
137+
kernels,
138+
strides,
139+
):
140+
super(UNet3D, self).__init__()
141+
self.dim = 3
142+
self.n_class = 3
143+
self.deep_supervision = True
144+
self.norm = "instancenorm3d"
145+
self.filters = [64, 128, 256, 512, 768, 1024, 2048][: len(strides)]
146+
147+
down_block = ConvBlock
148+
self.input_block = InputBlock(5, self.filters[0], norm=self.norm)
149+
self.downsamples = self.get_module_list(
150+
conv_block=down_block,
151+
in_channels=self.filters[:-1],
152+
out_channels=self.filters[1:],
153+
kernels=kernels[1:-1],
154+
strides=strides[1:-1],
155+
)
156+
self.bottleneck = self.get_conv_block(
157+
conv_block=down_block,
158+
in_channels=self.filters[-2],
159+
out_channels=self.filters[-1],
160+
kernel_size=kernels[-1],
161+
stride=strides[-1],
162+
)
163+
self.upsamples = self.get_module_list(
164+
conv_block=UpsampleBlock,
165+
in_channels=self.filters[1:][::-1],
166+
out_channels=self.filters[:-1][::-1],
167+
kernels=kernels[1:][::-1],
168+
strides=strides[1:][::-1],
169+
)
170+
self.output_block = self.get_output_block(decoder_level=0)
171+
self.deep_supervision_heads = self.get_deep_supervision_heads()
172+
self.apply(self.initialize_weights)
173+
174+
def forward(self, input_data):
175+
out = self.input_block(input_data)
176+
encoder_outputs = [out]
177+
for downsample in self.downsamples:
178+
out = downsample(out)
179+
encoder_outputs.append(out)
180+
out = self.bottleneck(out)
181+
decoder_outputs = []
182+
for upsample, skip in zip(self.upsamples, reversed(encoder_outputs)):
183+
out = upsample(out, skip)
184+
decoder_outputs.append(out)
185+
out = self.output_block(out)
186+
if self.training and self.deep_supervision:
187+
out = [out]
188+
for i, decoder_out in enumerate(decoder_outputs[-3:-1][::-1]):
189+
out.append(self.deep_supervision_heads[i](decoder_out))
190+
return out
191+
192+
def get_conv_block(self, conv_block, in_channels, out_channels, kernel_size, stride, drop_block=False):
193+
return conv_block(
194+
dim=self.dim,
195+
stride=stride,
196+
norm=self.norm,
197+
kernel_size=kernel_size,
198+
in_channels=in_channels,
199+
out_channels=out_channels,
200+
)
201+
202+
def get_output_block(self, decoder_level):
203+
return OutputBlock(in_channels=self.filters[decoder_level], out_channels=self.n_class, dim=self.dim)
204+
205+
def get_deep_supervision_heads(self):
206+
return nn.ModuleList([self.get_output_block(1), self.get_output_block(2)])
207+
208+
def get_module_list(self, in_channels, out_channels, kernels, strides, conv_block):
209+
layers = []
210+
for in_channel, out_channel, kernel, stride in zip(in_channels, out_channels, kernels, strides):
211+
conv_layer = self.get_conv_block(conv_block, in_channel, out_channel, kernel, stride)
212+
layers.append(conv_layer)
213+
return nn.ModuleList(layers)
214+
215+
def initialize_weights(self, module):
216+
name = module.__class__.__name__.lower()
217+
if name in ["conv2d", "conv3d"]:
218+
nn.init.kaiming_normal_(module.weight)
219+
if hasattr(module, "bias") and module.bias is not None:
220+
nn.init.constant_(module.bias, 0)

PyTorch/Segmentation/nnUNet/nnunet/nn_unet.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from utils.logger import DLLogger
3030
from utils.utils import get_config_file, print0
3131

32+
from nnunet.brats22_model import UNet3D
3233
from nnunet.loss import Loss, LossBraTS
3334
from nnunet.metrics import Dice
3435

@@ -69,6 +70,14 @@ def _forward(self, img):
6970
return self.tta_inference(img) if self.args.tta else self.do_inference(img)
7071

7172
def compute_loss(self, preds, label):
73+
if self.args.brats22_model:
74+
loss = self.loss(preds[0], label)
75+
for i, pred in enumerate(preds[1:]):
76+
downsampled_label = nn.functional.interpolate(label, pred.shape[2:])
77+
loss += 0.5 ** (i + 1) * self.loss(pred, downsampled_label)
78+
c_norm = 1 / (2 - 2 ** (-len(preds)))
79+
return c_norm * loss
80+
7281
if self.args.deep_supervision:
7382
loss, weights = 0.0, 0.0
7483
for i in range(preds.shape[1]):
@@ -152,21 +161,24 @@ def build_nnunet(self):
152161
if self.args.brats:
153162
out_channels = 3
154163

155-
self.model = DynUNet(
156-
self.args.dim,
157-
in_channels,
158-
out_channels,
159-
kernels,
160-
strides,
161-
strides[1:],
162-
filters=self.args.filters,
163-
norm_name=("INSTANCE", {"affine": True}),
164-
act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
165-
deep_supervision=self.args.deep_supervision,
166-
deep_supr_num=self.args.deep_supr_num,
167-
res_block=self.args.res_block,
168-
trans_bias=True,
169-
)
164+
if self.args.brats22_model:
165+
self.model = UNet3D(kernels, strides)
166+
else:
167+
self.model = DynUNet(
168+
self.args.dim,
169+
in_channels,
170+
out_channels,
171+
kernels,
172+
strides,
173+
strides[1:],
174+
filters=self.args.filters,
175+
norm_name=("INSTANCE", {"affine": True}),
176+
act_name=("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
177+
deep_supervision=self.args.deep_supervision,
178+
deep_supr_num=self.args.deep_supr_num,
179+
res_block=self.args.res_block,
180+
trans_bias=True,
181+
)
170182
print0(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
171183

172184
def do_inference(self, image):

PyTorch/Segmentation/nnUNet/notebooks/BraTS22.ipynb

Lines changed: 720 additions & 0 deletions
Large diffs are not rendered by default.

PyTorch/Segmentation/nnUNet/utils/args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def get_main_args(strings=None):
5757
arg("--tta", action="store_true", help="Enable test time augmentation")
5858
arg("--tb_logs", action="store_true", help="Log metrics to tensoboard")
5959
arg("--brats", action="store_true", help="Enable BraTS specific training and inference")
60+
arg("--brats22_model", action="store_true", help="Use BraTS22 model")
6061
arg("--deep_supervision", action="store_true", help="Enable deep supervision")
6162
arg("--more_chn", action="store_true", help="Create encoder with more channels")
6263
arg("--invert_resampled_y", action="store_true", help="Resize predictions to match label size before resampling")

0 commit comments

Comments
 (0)