@@ -323,83 +323,76 @@ def _calc_error_in_mean(self, ys, y_avg, n):
323323 return t_student * (variance_in_mean / n ) ** 0.5
324324
def tell_many(self, xs, ys):
    """Tell the learner about several samples, possibly at repeated x-values.

    Consecutive entries of ``xs`` that compare equal are treated as
    repeated samples of the same point: a run of length one is forwarded
    to ``self.tell`` and a longer run is forwarded in a single call to
    ``self.tell_many_samples``. Equal x-values that are *not* adjacent in
    ``xs`` are handled as separate runs.

    Parameters
    ----------
    xs : Iterable of values from the function domain
    ys : Iterable of data samples, paired element-wise with ``xs``

    Raises
    ------
    ValueError
        If any x lies outside ``self.bounds``. The check is done for all
        points before any of them is told to the learner.
    """
    # Local import so this method does not depend on the file's import block.
    from itertools import groupby

    # Materialise the inputs: the bounds check walks over xs once, which
    # would exhaust an iterator/generator before the samples are told.
    xs = list(xs)
    ys = list(ys)

    lower, upper = self.bounds[0], self.bounds[1]
    if not all(lower <= x <= upper for x in xs):
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    # groupby only merges *adjacent* equal keys, which is exactly the
    # run-based grouping this method is meant to perform.
    for x, run in groupby(zip(xs, ys), key=lambda pair: pair[0]):
        samples = [y for _, y in run]
        if len(samples) == 1:
            self.tell(x, samples[0])
        else:
            # More than one sample at this x: use the routine that
            # ingests them all at once instead of telling one by one.
            self.tell_many_samples(x, samples)
def tell_many_samples(self, x, ys):
    """Tell the learner about many samples at a certain location x.

    Parameters
    ----------
    x : Value from the function domain
    ys : Iterable of data samples at x. It is copied on entry, so the
        caller's container is never modified. An empty iterable is a
        no-op.

    Raises
    ------
    ValueError
        If x lies outside ``self.bounds``.
    """
    # Check x is within the bounds
    if not (self.bounds[0] <= x <= self.bounds[1]):
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    # Work on a private copy: the first sample is popped below and the
    # caller's list must not be altered as a side effect.
    samples = list(ys)
    if not samples:
        # Nothing to add; avoids feeding a NaN mean into the scale update.
        return

    y_avg = np.mean(samples)
    # If x is a new point, register it using the first sample.
    if x not in self.data:
        first = samples.pop(0)
        self._update_data(x, first, "new")
        self._update_data_structures(x, first, "new")
    # If x is not a new point or if there were more than 1 sample in ys:
    if samples:
        self._data_samples.update({x: samples + self._data_samples[x]})
        n = len(self._data_samples[x])
        # The mean must cover every sample stored at x, not only the
        # ones passed in this call (they differ when x already existed).
        y_avg = np.mean(self._data_samples[x])
        self.data[x] = y_avg
        self._number_samples[x] = n
        # self._update_data(x,y,"new") included the point
        # in _undersampled_points. We remove it if there are
        # more than min_samples samples, disregarding neighbor_sampling.
        if n > self.min_samples:
            self._undersampled_points.discard(x)
        self._error_in_mean[x] = self._calc_error_in_mean(
            self._data_samples[x], y_avg, n
        )
        self._update_distances(x)
        self._update_rescaled_error_in_mean(x, "resampled")
        # Points that are accurate enough, or already sampled the maximum
        # number of times, are dropped from _rescaled_error_in_mean.
        if self._error_in_mean[x] <= self.min_error or n >= self.max_samples:
            self._rescaled_error_in_mean.pop(x, None)
    super()._update_scale(x, y_avg)
    self._update_losses_resampling(x, real=True)
    # Recompute every interval loss once the y-scale has grown beyond the
    # configured factor since the last full recomputation.
    if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
        for interval in reversed(self.losses):
            self._update_interpolated_loss_in_interval(*interval)
        self._oldscale = deepcopy(self._scale)
403396
404397 def plot (self ):
405398 """Returns a plot of the evaluated data with error bars (not implemented
0 commit comments