@@ -180,6 +180,7 @@ def on_batch_end(self, batch, logs=None):
180180 self .timestamp_log .append (BatchTimestamp (self .global_steps , now ))
181181 elapsed_time_str = '{:.2f} seconds' .format (elapsed_time )
182182 self .logger .log (step = 'PARAMETER' , data = {'Latency' : elapsed_time_str , 'fps' : examples_per_second , 'steps' : (self .last_log_step , self .global_steps )})
183+ self .logger .flush ()
183184
184185 if self .summary_writer :
185186 with self .summary_writer .as_default ():
@@ -371,13 +372,14 @@ def on_epoch_end(self, epoch, logs=None):
371372
372373
373374class COCOEvalCallback (tf .keras .callbacks .Callback ):
374- def __init__ (self , eval_dataset , eval_freq , start_eval_epoch , eval_params , ** kwargs ):
375+ def __init__ (self , eval_dataset , eval_freq , start_eval_epoch , eval_params , logger , ** kwargs ):
375376 super (COCOEvalCallback , self ).__init__ (** kwargs )
376377 self .dataset = eval_dataset
377378 self .eval_freq = eval_freq
378379 self .start_eval_epoch = start_eval_epoch
379380 self .eval_params = eval_params
380381 self .ema_opt = None
382+ self .logger = logger
381383
382384 label_map = label_util .get_label_map (eval_params ['label_map' ])
383385 self .evaluator = coco_metric .EvaluationMetric (
@@ -425,6 +427,8 @@ def evaluate(self, epoch):
425427 csv_metrics = ['AP' ,'AP50' ,'AP75' ,'APs' ,'APm' ,'APl' ]
426428 csv_format = "," .join ([str (epoch + 1 )] + [str (round (metric_dict [key ] * 100 , 2 )) for key in csv_metrics ])
427429 print (metric_dict , "csv format:" , csv_format )
430+ self .logger .log (step = (), data = {'epoch' : epoch + 1 ,
431+ 'validation_accuracy_mAP' : round (metric_dict ['AP' ] * 100 , 2 )})
428432
429433 if self .eval_params ['moving_average_decay' ] > 0 :
430434 self .ema_opt .swap_weights () # get base weights
@@ -492,7 +496,8 @@ def get_callbacks(
492496 cocoeval = COCOEvalCallback (eval_dataset ,
493497 eval_freq = params ['checkpoint_period' ],
494498 start_eval_epoch = 200 ,
495- eval_params = eval_params )
499+ eval_params = eval_params ,
500+ logger = logger )
496501 callbacks .append (cocoeval )
497502
498503 if params ['moving_average_decay' ]:
@@ -504,4 +509,4 @@ def get_callbacks(
504509 os .path .join (params ['model_dir' ], 'train' ))
505510 callbacks .append (display_callback )
506511
507- return callbacks
512+ return callbacks
0 commit comments