@@ -513,32 +513,27 @@ def _recompute_all_losses(self):
513513 @property
514514 def _scale (self ):
515515 # get the output scale
516- scale = self ._max_value - self ._min_value
517- if isinstance (scale , np .ndarray ):
518- scale [scale == 0 ] = 1
519- elif scale == 0 :
520- scale = 1
521- return scale
516+ return np .max (self ._max_value - self ._min_value )
522517
523518 def _update_range (self , new_output ):
524519 if self ._min_value is None or self ._max_value is None :
525520 # this is the first point, nothing to do, just set the range
526521 self ._min_value = np .array (new_output )
527522 self ._max_value = np .array (new_output )
528- self ._old_scale = self ._scale
523+ self ._old_scale = self ._scale or 1
529524 return False
530525
531526 # if range in one or more directions is doubled, then update all losses
532527 self ._min_value = np .minimum (self ._min_value , new_output )
533528 self ._max_value = np .maximum (self ._max_value , new_output )
534529
535- scale_multiplier = 1 / self ._scale
530+ scale_multiplier = 1 / ( self ._scale or 1 )
536531 if isinstance (scale_multiplier , float ):
537532 scale_multiplier = np .array ([scale_multiplier ], dtype = float )
538533
539534 # the maximum absolute value that is in the range. Because this is the
540535 # largest number, this also has the largest absolute numerical error.
541- max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]), axis = 0 )
536+ max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]))
542537 # since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
543538 abs_err = 1e-15 * max_absolute_value_in_range
544539 # when scaling the floats, the error gets increased.
0 commit comments