5151from apex import amp
5252
5353
def synchronized_timestamp():
    """Return a wall-clock timestamp taken after pending CUDA work completes.

    CUDA kernels launch asynchronously, so reading ``time.time()`` directly
    would measure launch time rather than execution time.  Synchronizing the
    device first makes the returned timestamp valid for bracketing GPU
    workloads (train/eval epoch timing, throughput computation).

    Returns:
        float: seconds since the epoch, as given by ``time.time()``.
    """
    # Guard keeps the helper usable on CPU-only builds, where the original
    # unconditional synchronize() would raise; with CUDA present behavior
    # is identical to calling torch.cuda.synchronize() directly.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
5458def parse_args ():
5559 parser = ArgumentParser (description = "Train a Neural Collaborative"
5660 " Filtering model" )
@@ -218,7 +222,7 @@ def main():
218222 torch .distributed .broadcast (torch .tensor ([1 ], device = "cuda" ), 0 )
219223 torch .cuda .synchronize ()
220224
221- main_start_time = time . time ()
225+ main_start_time = synchronized_timestamp ()
222226
223227 feature_spec_path = os .path .join (args .data , args .feature_spec_file )
224228 feature_spec = FeatureSpec .from_yaml (feature_spec_path )
@@ -268,10 +272,10 @@ def main():
268272 model .load_state_dict (state_dict )
269273
270274 if args .mode == 'test' :
271- start = time . time ()
275+ start = synchronized_timestamp ()
272276 hr , ndcg , val_loss = val_epoch (model , test_loader , args .topk ,
273277 distributed = args .distributed , world_size = args .world_size )
274- val_time = time . time () - start
278+ val_time = synchronized_timestamp () - start
275279 eval_size = test_loader .raw_dataset_length
276280 eval_throughput = eval_size / val_time
277281
@@ -285,12 +289,12 @@ def main():
285289 # to an uninitialized variable.
286290 max_hr = 0
287291 best_epoch = 0
288- best_model_timestamp = time . time ()
292+ best_model_timestamp = synchronized_timestamp ()
289293 train_throughputs , eval_throughputs = [], []
290294
291295 for epoch in range (args .epochs ):
292296
293- begin = time . time ()
297+ begin = synchronized_timestamp ()
294298 batch_dict_list = train_loader .get_epoch_data ()
295299 num_batches = len (batch_dict_list )
296300 for i in range (num_batches // args .grads_accumulated ):
@@ -322,8 +326,8 @@ def main():
322326 p .grad = None
323327
324328 del batch_dict_list
325- train_time = time . time () - begin
326- begin = time . time ()
329+ train_time = synchronized_timestamp () - begin
330+ begin = synchronized_timestamp ()
327331
328332 epoch_samples = train_loader .length_after_augmentation
329333 train_throughput = epoch_samples / train_time
@@ -332,7 +336,7 @@ def main():
332336 hr , ndcg , val_loss = val_epoch (model , test_loader , args .topk ,
333337 distributed = args .distributed , world_size = args .world_size )
334338
335- val_time = time . time () - begin
339+ val_time = synchronized_timestamp () - begin
336340 eval_size = test_loader .raw_dataset_length
337341 eval_throughput = eval_size / val_time
338342 eval_throughputs .append (eval_throughput )
@@ -358,7 +362,7 @@ def main():
358362 save_checkpoint_path = os .path .join (args .checkpoint_dir , 'model.pth' )
359363 print ("Saving the model to: " , save_checkpoint_path )
360364 torch .save (model .state_dict (), save_checkpoint_path )
361- best_model_timestamp = time . time ()
365+ best_model_timestamp = synchronized_timestamp ()
362366
363367 if args .threshold is not None :
364368 if hr >= args .threshold :
@@ -372,7 +376,7 @@ def main():
372376 'mean_eval_throughput' : np .mean (eval_throughputs ),
373377 'best_accuracy' : max_hr ,
374378 'best_epoch' : best_epoch ,
375- 'time_to_target' : time . time () - main_start_time ,
379+ 'time_to_target' : synchronized_timestamp () - main_start_time ,
376380 'time_to_best_model' : best_model_timestamp - main_start_time ,
377381 'validation_loss' : float (val_loss .item ()),
378382 'train_loss' : float (loss .item ())},
0 commit comments