|
22 | 22 | from data_loading.data_module import get_data_path, get_test_fnames |
23 | 23 | from monai.inferers import sliding_window_inference |
24 | 24 | from monai.networks.nets import DynUNet |
| 25 | +from nnunet.brats22_model import UNet3D |
| 26 | +from nnunet.loss import Loss, LossBraTS |
| 27 | +from nnunet.metrics import Dice |
25 | 28 | from pytorch_lightning.utilities import rank_zero_only |
26 | 29 | from scipy.special import expit, softmax |
27 | 30 | from skimage.transform import resize |
28 | 31 | from utils.logger import DLLogger |
29 | 32 | from utils.utils import get_config_file, print0 |
30 | 33 |
|
31 | | -from nnunet.brats22_model import UNet3D |
32 | | -from nnunet.loss import Loss, LossBraTS |
33 | | -from nnunet.metrics import Dice |
34 | | - |
35 | 34 |
|
36 | 35 | class NNUnet(pl.LightningModule): |
37 | 36 | def __init__(self, args, triton=False, data_dir=None): |
@@ -279,7 +278,7 @@ def test_epoch_end(self, outputs): |
279 | 278 |
|
280 | 279 | @rank_zero_only |
281 | 280 | def on_fit_end(self): |
282 | | - if not self.args.benchmark and self.args.skip_first_n_eval == 0: |
| 281 | + if not self.args.benchmark: |
283 | 282 | metrics = {} |
284 | 283 | metrics["dice_score"] = round(self.best_mean.item(), 2) |
285 | 284 | metrics["train_loss"] = round(sum(self.train_loss) / len(self.train_loss), 4) |
|
0 commit comments