Skip to content

Commit 0915477

Browse files
committed
Merge: [NCF/PyT] Fix time measurement bugs
2 parents 0522471 + 77b6eab commit 0915477

1 file changed

Lines changed: 14 additions & 10 deletions

File tree

  • PyTorch/Recommendation/NCF

PyTorch/Recommendation/NCF/ncf.py

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,10 @@
5151
from apex import amp
5252

5353

54+
def synchronized_timestamp():
55+
torch.cuda.synchronize()
56+
return time.time()
57+
5458
def parse_args():
5559
parser = ArgumentParser(description="Train a Neural Collaborative"
5660
" Filtering model")
@@ -218,7 +222,7 @@ def main():
218222
torch.distributed.broadcast(torch.tensor([1], device="cuda"), 0)
219223
torch.cuda.synchronize()
220224

221-
main_start_time = time.time()
225+
main_start_time = synchronized_timestamp()
222226

223227
feature_spec_path = os.path.join(args.data, args.feature_spec_file)
224228
feature_spec = FeatureSpec.from_yaml(feature_spec_path)
@@ -268,10 +272,10 @@ def main():
268272
model.load_state_dict(state_dict)
269273

270274
if args.mode == 'test':
271-
start = time.time()
275+
start = synchronized_timestamp()
272276
hr, ndcg, val_loss = val_epoch(model, test_loader, args.topk,
273277
distributed=args.distributed, world_size=args.world_size)
274-
val_time = time.time() - start
278+
val_time = synchronized_timestamp() - start
275279
eval_size = test_loader.raw_dataset_length
276280
eval_throughput = eval_size / val_time
277281

@@ -285,12 +289,12 @@ def main():
285289
# to an uninitialized variable.
286290
max_hr = 0
287291
best_epoch = 0
288-
best_model_timestamp = time.time()
292+
best_model_timestamp = synchronized_timestamp()
289293
train_throughputs, eval_throughputs = [], []
290294

291295
for epoch in range(args.epochs):
292296

293-
begin = time.time()
297+
begin = synchronized_timestamp()
294298
batch_dict_list = train_loader.get_epoch_data()
295299
num_batches = len(batch_dict_list)
296300
for i in range(num_batches // args.grads_accumulated):
@@ -322,8 +326,8 @@ def main():
322326
p.grad = None
323327

324328
del batch_dict_list
325-
train_time = time.time() - begin
326-
begin = time.time()
329+
train_time = synchronized_timestamp() - begin
330+
begin = synchronized_timestamp()
327331

328332
epoch_samples = train_loader.length_after_augmentation
329333
train_throughput = epoch_samples / train_time
@@ -332,7 +336,7 @@ def main():
332336
hr, ndcg, val_loss = val_epoch(model, test_loader, args.topk,
333337
distributed=args.distributed, world_size=args.world_size)
334338

335-
val_time = time.time() - begin
339+
val_time = synchronized_timestamp() - begin
336340
eval_size = test_loader.raw_dataset_length
337341
eval_throughput = eval_size / val_time
338342
eval_throughputs.append(eval_throughput)
@@ -358,7 +362,7 @@ def main():
358362
save_checkpoint_path = os.path.join(args.checkpoint_dir, 'model.pth')
359363
print("Saving the model to: ", save_checkpoint_path)
360364
torch.save(model.state_dict(), save_checkpoint_path)
361-
best_model_timestamp = time.time()
365+
best_model_timestamp = synchronized_timestamp()
362366

363367
if args.threshold is not None:
364368
if hr >= args.threshold:
@@ -372,7 +376,7 @@ def main():
372376
'mean_eval_throughput': np.mean(eval_throughputs),
373377
'best_accuracy': max_hr,
374378
'best_epoch': best_epoch,
375-
'time_to_target': time.time() - main_start_time,
379+
'time_to_target': synchronized_timestamp() - main_start_time,
376380
'time_to_best_model': best_model_timestamp - main_start_time,
377381
'validation_loss': float(val_loss.item()),
378382
'train_loss': float(loss.item())},

0 commit comments

Comments (0)