@@ -168,16 +168,19 @@ def format_log(loss, split, args):
     return log_str
 
 
-def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
+def evaluate(
+    eval_iter, model, device, meters, log_interval, max_size=None, repeat=1
+):
     total_len, total_loss = 0, 0.
     eval_step = 0
 
     log_throughput = 0
     log_latency = 0
     log_loss = 0
 
-    torch.cuda.synchronize()
+    utils.distributed.barrier()
     start_time = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(repeat):
@@ -186,10 +189,12 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                 break
             eval_step += 1
 
-            torch.cuda.synchronize()
+            utils.distributed.barrier()
             start_iter = time.time()
+
             loss, mems = model(data, target, mems)
-            torch.cuda.synchronize()
+
+            utils.distributed.barrier()
             elapsed = time.time() - start_iter
 
             loss = loss.float().mean()
@@ -204,7 +209,7 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
             target_tokens = target.numel()
             throughput = target_tokens / elapsed
             throughput = utils.distributed.all_reduce_item(throughput, op='sum')
-            meters['eval_throughput'].update(throughput)
+            meters['eval_throughput'].update(throughput, elapsed)
            log_throughput += throughput
 
            if eval_step % log_interval == 0:
@@ -238,8 +243,8 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
                 log_loss = 0
 
     utils.distributed.barrier()
-    torch.cuda.synchronize()
     total_time = time.time() - start_time
+
     logging.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
         total_time, 1000 * total_time / (idx + 1)))
 
@@ -251,13 +256,18 @@ def evaluate(eval_iter, model, meters, log_interval, max_size=None, repeat=1):
 def compile_model(model, device, args):
     inp = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
     tgt = torch.randint(0, 1000, (args.tgt_len, args.batch_size)).to(device)
+
+    utils.distributed.barrier()
     start = time.time()
+
     with torch.no_grad():
         mems = None
         for _ in range(2):
             _, mems = model(inp, tgt, mems)
-    torch.cuda.synchronize()
+
+    utils.distributed.barrier()
     stop = time.time()
+
     logging.info(f'Building the model took {stop - start:.2f} seconds')
 
 
@@ -450,7 +460,7 @@ def main():
     meters['eval_throughput'] = AverageMeter(warmup=warmup, keep=args.save_data)
     meters['eval_latency'] = AverageMeter(warmup=warmup, keep=args.save_data)
 
-    loss = evaluate(iter, model, meters, args.log_interval, args.max_size, args.repeat)
+    loss = evaluate(iter, model, device, meters, args.log_interval, args.max_size, args.repeat)
     perplexity = math.exp(loss)
     log_str = format_log(loss, args.split, args)
 
@@ -476,15 +486,17 @@ def main():
         }
         with open(data_path, 'wb') as f:
             pickle.dump(data, f)
-        logging.info(f'Throughput Avg: {throughput_data.mean():.2f} tok/s')
+
+        avg_throughput = meters['eval_throughput'].avg
+        logging.info(f'Throughput Avg: {avg_throughput:.2f} tok/s')
         logging.info(f'Latency Avg: {1000.0 * latency_data.mean():.2f} ms')
         for p in args.percentiles:
             logging.info(f'Latency {p}%: {1000.0 * np.percentile(latency_data, p):.2f} ms')
 
         logging.info('=' * 100)
 
         summary.update({
-            'eval_throughput': throughput_data.mean(),
+            'eval_throughput': avg_throughput,
             'eval_avg_latency': 1000 * latency_data.mean(),
         })
         for p in args.percentiles:
0 commit comments