Skip to content

Commit c531cd9

Browse files
AdamRajfern authored and nv-kkudrynski committed
[ConvNets/PyT] Enable logging gradient scale
1 parent eb35710 commit c531cd9

3 files changed

Lines changed: 13 additions & 2 deletions

File tree

PyTorch/Classification/ConvNets/image_classification/logger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def __init__(self, logger):
370370
"data_time": ["train.data_time"],
371371
"compute_time": ["train.compute_time"],
372372
"lr": ["train.lr"],
373+
"grad_scale": ["train.grad_scale"],
373374
}
374375
logger.register_metric(
375376
"train.loss",
@@ -406,6 +407,12 @@ def __init__(self, logger):
406407
LR_METER(),
407408
verbosity=dllogger.Verbosity.DEFAULT,
408409
)
410+
logger.register_metric(
411+
"train.grad_scale",
412+
PERF_METER(),
413+
verbosity=dllogger.Verbosity.DEFAULT,
414+
metadata=Metrics.LOSS_METADATA,
415+
)
409416

410417

411418
class ValidationMetrics(Metrics):

PyTorch/Classification/ConvNets/image_classification/training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def train(
206206
train_step,
207207
train_loader,
208208
lr_scheduler,
209+
grad_scale_fn,
209210
log_fn,
210211
timeout_handler,
211212
prof=-1,
@@ -238,6 +239,7 @@ def train(
238239
compute_time=it_time - data_time,
239240
lr=lr,
240241
loss=reduced_loss.item(),
242+
grad_scale=grad_scale_fn(),
241243
)
242244

243245
end = time.time()
@@ -364,6 +366,7 @@ def train_loop(
364366
training_step,
365367
data_iter,
366368
lambda i: lr_scheduler(trainer.optimizer, i, epoch),
369+
trainer.executor.scaler.get_scale,
367370
train_metrics.log,
368371
timeout_handler,
369372
prof=prof,

PyTorch/Classification/ConvNets/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ def _worker_init_fn(id):
422422
print("BSM: {}".format(batch_size_multiplier))
423423

424424
start_epoch = 0
425+
best_prec1 = 0
425426
# optionally resume from a checkpoint
426427
if args.resume is not None:
427428
if os.path.isfile(args.resume):
@@ -609,13 +610,12 @@ def _worker_init_fn(id):
609610
val_loader,
610611
logger,
611612
start_epoch,
613+
best_prec1,
612614
)
613615

614616

615617
def main(args, model_args, model_arch):
616618
exp_start_time = time.time()
617-
global best_prec1
618-
best_prec1 = 0
619619

620620
(
621621
trainer,
@@ -625,6 +625,7 @@ def main(args, model_args, model_arch):
625625
val_loader,
626626
logger,
627627
start_epoch,
628+
best_prec1,
628629
) = prepare_for_training(args, model_args, model_arch)
629630

630631
train_loop(

0 commit comments

Comments (0)