Skip to content

Commit bd1fb86

Browse files
committed
Merge: [ConvNets/PyT] Enable logging gradient scale
2 parents 4954233 + c531cd9 commit bd1fb86

3 files changed

Lines changed: 13 additions & 2 deletions

File tree

PyTorch/Classification/ConvNets/image_classification/logger.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def __init__(self, logger):
370370
"data_time": ["train.data_time"],
371371
"compute_time": ["train.compute_time"],
372372
"lr": ["train.lr"],
373+
"grad_scale": ["train.grad_scale"],
373374
}
374375
logger.register_metric(
375376
"train.loss",
@@ -406,6 +407,12 @@ def __init__(self, logger):
406407
LR_METER(),
407408
verbosity=dllogger.Verbosity.DEFAULT,
408409
)
410+
logger.register_metric(
411+
"train.grad_scale",
412+
PERF_METER(),
413+
verbosity=dllogger.Verbosity.DEFAULT,
414+
metadata=Metrics.LOSS_METADATA,
415+
)
409416

410417

411418
class ValidationMetrics(Metrics):

PyTorch/Classification/ConvNets/image_classification/training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def train(
206206
train_step,
207207
train_loader,
208208
lr_scheduler,
209+
grad_scale_fn,
209210
log_fn,
210211
timeout_handler,
211212
prof=-1,
@@ -238,6 +239,7 @@ def train(
238239
compute_time=it_time - data_time,
239240
lr=lr,
240241
loss=reduced_loss.item(),
242+
grad_scale=grad_scale_fn(),
241243
)
242244

243245
end = time.time()
@@ -364,6 +366,7 @@ def train_loop(
364366
training_step,
365367
data_iter,
366368
lambda i: lr_scheduler(trainer.optimizer, i, epoch),
369+
trainer.executor.scaler.get_scale,
367370
train_metrics.log,
368371
timeout_handler,
369372
prof=prof,

PyTorch/Classification/ConvNets/main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def _worker_init_fn(id):
416416
print("BSM: {}".format(batch_size_multiplier))
417417

418418
start_epoch = 0
419+
best_prec1 = 0
419420
# optionally resume from a checkpoint
420421
if args.resume is not None:
421422
if os.path.isfile(args.resume):
@@ -603,13 +604,12 @@ def _worker_init_fn(id):
603604
val_loader,
604605
logger,
605606
start_epoch,
607+
best_prec1,
606608
)
607609

608610

609611
def main(args, model_args, model_arch):
610612
exp_start_time = time.time()
611-
global best_prec1
612-
best_prec1 = 0
613613

614614
(
615615
trainer,
@@ -619,6 +619,7 @@ def main(args, model_args, model_arch):
619619
val_loader,
620620
logger,
621621
start_epoch,
622+
best_prec1,
622623
) = prepare_for_training(args, model_args, model_arch)
623624

624625
train_loop(

0 commit comments

Comments (0)