@@ -483,7 +483,7 @@ def _compute_loss(self, simplex):
483483
484484 # scale them to a cube with sides 1
485485 vertices = vertices @ self ._transform
486- values = self ._output_multiplier * values
486+ values = self ._output_multiplier * np . array ( values )
487487
488488 # compute the loss on the scaled simplex
489489 return float (self .loss_per_simplex (vertices , values ))
@@ -513,40 +513,38 @@ def _recompute_all_losses(self):
513513 @property
514514 def _scale (self ):
515515 # get the output scale
516- return np . max ( self ._max_value - self ._min_value )
516+ return self ._max_value - self ._min_value
517517
518518 def _update_range (self , new_output ):
519519 if self ._min_value is None or self ._max_value is None :
520520 # this is the first point, nothing to do, just set the range
521- self ._min_value = np .array (new_output )
522- self ._max_value = np .array (new_output )
521+ self ._min_value = np .min (new_output )
522+ self ._max_value = np .max (new_output )
523523 self ._old_scale = self ._scale or 1
524524 return False
525525
526526 # if range in one or more directions is doubled, then update all losses
527- self ._min_value = np . minimum (self ._min_value , new_output )
528- self ._max_value = np . maximum (self ._max_value , new_output )
527+ self ._min_value = min (self ._min_value , np . min ( new_output ) )
528+ self ._max_value = max (self ._max_value , np . max ( new_output ) )
529529
530530 scale_multiplier = 1 / (self ._scale or 1 )
531- if isinstance (scale_multiplier , float ):
532- scale_multiplier = np .array ([scale_multiplier ], dtype = float )
533531
534532 # the maximum absolute value that is in the range. Because this is the
535533 # largest number, this also has the largest absolute numerical error.
536- max_absolute_value_in_range = np .max (np .abs ([self ._min_value , self ._max_value ]))
534+ max_absolute_value_in_range = max (abs (self ._min_value ),
535+ abs (self ._max_value ))
537536 # since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
538537 abs_err = 1e-15 * max_absolute_value_in_range
539538 # when scaling the floats, the error gets increased.
540539 scaled_err = abs_err * scale_multiplier
541540
542- allowed_numerical_error = 1e-2
543-
544541 # do not scale along the axis if the numerical error gets too big
545- scale_multiplier [scaled_err > allowed_numerical_error ] = 1
542+ if scaled_err > 1e-2 : # allowed_numerical_error = 1e-2
543+ scale_multiplier = 1
546544
547545 self ._output_multiplier = scale_multiplier
548546
549- scale_factor = np . max ( self ._scale / self ._old_scale )
547+ scale_factor = self ._scale / self ._old_scale
550548 if scale_factor > self ._recompute_losses_factor :
551549 self ._old_scale = self ._scale
552550 self ._recompute_all_losses ()
0 commit comments