|
14 | 14 |
|
15 | 15 | import os |
16 | 16 |
|
| 17 | +import torch |
17 | 18 | from pytorch_lightning import Trainer, seed_everything |
18 | 19 | from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, RichProgressBar |
19 | | -from pytorch_lightning.loggers import TensorBoardLogger |
| 20 | +from pytorch_lightning.plugins.io import AsyncCheckpointIO |
| 21 | +from pytorch_lightning.strategies import DDPStrategy |
20 | 22 |
|
21 | 23 | from data_loading.data_module import DataModule |
22 | 24 | from nnunet.nn_unet import NNUnet |
23 | 25 | from utils.args import get_main_args |
24 | 26 | from utils.logger import LoggingCallback |
25 | 27 | from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path |
26 | 28 |
|
# Allow TF32 math on Ampere-or-newer GPUs: matmul and cuDNN convolution
# kernels run noticeably faster at a slightly reduced mantissa precision,
# which is an accepted trade-off for this training workload.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_trainer(args, callbacks):
    """Construct a PyTorch Lightning ``Trainer`` from parsed CLI arguments.

    Args:
        args: parsed arguments namespace; reads ``results``, ``epochs``,
            ``amp``, ``gradient_clip_val``, ``save_ckpt``, ``gpus``,
            ``nodes``, ``train_batches`` and ``test_batches``.
        callbacks: list of Lightning callbacks to attach.

    Returns:
        A configured ``Trainer`` using DDP across ``args.gpus`` devices
        with asynchronous checkpoint writing.
    """
    # A value of 0 for the batch-limit args means "use the full dataset"
    # (Lightning expresses that as the fraction 1.0).
    train_limit = 1.0 if args.train_batches == 0 else args.train_batches
    eval_limit = 1.0 if args.test_batches == 0 else args.test_batches

    # Static-graph DDP with bucket-view gradients avoids redundant
    # gradient copies; the model has no unused parameters.
    ddp_strategy = DDPStrategy(
        find_unused_parameters=False,
        static_graph=True,
        gradient_as_bucket_view=True,
    )

    return Trainer(
        logger=False,
        default_root_dir=args.results,
        benchmark=True,  # cuDNN autotuner: fixed input shapes benefit
        deterministic=False,
        max_epochs=args.epochs,
        precision=16 if args.amp else 32,
        gradient_clip_val=args.gradient_clip_val,
        enable_checkpointing=args.save_ckpt,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        accelerator="gpu",
        devices=args.gpus,
        num_nodes=args.nodes,
        # Write checkpoints on a background thread so training is not
        # blocked on disk I/O.
        plugins=[AsyncCheckpointIO()],
        strategy=ddp_strategy,
        limit_train_batches=train_limit,
        limit_val_batches=eval_limit,
        limit_test_batches=eval_limit,
    )
| 59 | + |
| 60 | +def main(): |
28 | 61 | args = get_main_args() |
29 | | - set_granularity() # Increase maximum fetch granularity of L2 to 128 bytes |
| 62 | + set_granularity() |
30 | 63 | set_cuda_devices(args) |
31 | 64 | if args.seed is not None: |
32 | 65 | seed_everything(args.seed) |
33 | 66 | data_module = DataModule(args) |
34 | 67 | data_module.setup() |
35 | 68 | ckpt_path = verify_ckpt_path(args) |
36 | 69 |
|
37 | | - model = NNUnet(args) |
| 70 | + if ckpt_path is not None: |
| 71 | + model = NNUnet.load_from_checkpoint(ckpt_path, strict=False, args=args) |
| 72 | + else: |
| 73 | + model = NNUnet(args) |
38 | 74 | callbacks = [RichProgressBar(), ModelSummary(max_depth=2)] |
39 | | - logger = False |
40 | 75 | if args.benchmark: |
41 | 76 | batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size |
42 | 77 | filnename = args.logname if args.logname is not None else "perf.json" |
|
51 | 86 | ) |
52 | 87 | ) |
53 | 88 | elif args.exec_mode == "train": |
54 | | - if args.tb_logs: |
55 | | - logger = TensorBoardLogger( |
56 | | - save_dir=f"{args.results}/tb_logs", |
57 | | - name=f"task={args.task}_dim={args.dim}_fold={args.fold}_precision={16 if args.amp else 32}", |
58 | | - default_hp_metric=False, |
59 | | - version=0, |
60 | | - ) |
61 | 89 | if args.save_ckpt: |
62 | 90 | callbacks.append( |
63 | 91 | ModelCheckpoint( |
|
69 | 97 | ) |
70 | 98 | ) |
71 | 99 |
|
72 | | - trainer = Trainer( |
73 | | - logger=logger, |
74 | | - default_root_dir=args.results, |
75 | | - benchmark=True, |
76 | | - deterministic=False, |
77 | | - max_epochs=args.epochs, |
78 | | - precision=16 if args.amp else 32, |
79 | | - gradient_clip_val=args.gradient_clip_val, |
80 | | - enable_checkpointing=args.save_ckpt, |
81 | | - callbacks=callbacks, |
82 | | - num_sanity_val_steps=0, |
83 | | - accelerator="gpu", |
84 | | - devices=args.gpus, |
85 | | - num_nodes=args.nodes, |
86 | | - strategy="ddp" if args.gpus > 1 else None, |
87 | | - limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches, |
88 | | - limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches, |
89 | | - limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches, |
90 | | - ) |
91 | | - |
| 100 | + trainer = get_trainer(args, callbacks) |
92 | 101 | if args.benchmark: |
93 | 102 | if args.exec_mode == "train": |
94 | 103 | trainer.fit(model, train_dataloaders=data_module.train_dataloader()) |
|
99 | 108 | model.start_benchmark = 1 |
100 | 109 | trainer.test(model, dataloaders=data_module.test_dataloader(), verbose=False) |
101 | 110 | elif args.exec_mode == "train": |
102 | | - trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path) |
| 111 | + trainer.fit(model, datamodule=data_module) |
103 | 112 | elif args.exec_mode == "evaluate": |
104 | 113 | trainer.validate(model, dataloaders=data_module.val_dataloader()) |
105 | 114 | elif args.exec_mode == "predict": |
|
113 | 122 | model.save_dir = save_dir |
114 | 123 | make_empty_dir(save_dir) |
115 | 124 | model.args = args |
116 | | - trainer.test(model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path) |
| 125 | + trainer.test(model, dataloaders=data_module.test_dataloader()) |
# Script entry point: delegate all work to main() so the module can be
# imported without side effects beyond the TF32 flags above.
if __name__ == "__main__":
    main()
0 commit comments