Skip to content

Commit 36c4b60

Browse files
AlvaroGI authored and basnijholt committed
tell_many() redesigned
This method now matches the definition from the BaseLearner. It provides computational efficiency in some scenarios (see the comments in the code); otherwise it simply performs a loop calling tell(x, y).
1 parent f5631ac commit 36c4b60

1 file changed

Lines changed: 66 additions & 73 deletions

File tree

adaptive/learner/average_learner1D.py

Lines changed: 66 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -323,83 +323,76 @@ def _calc_error_in_mean(self, ys, y_avg, n):
323323
return t_student * (variance_in_mean / n) ** 0.5
324324

325325
def tell_many(self, xs, ys):
    """Tell the learner about some values.

    Consecutive runs of samples that share the same x are buffered and
    forwarded in one call to `tell_many_samples`, which updates the
    internal data structures once per x instead of once per sample;
    isolated samples are forwarded to `tell`.

    Parameters
    ----------
    xs : iterable of values from the function domain
    ys : iterable of samples from the function image, aligned with ``xs``

    Raises
    ------
    ValueError
        If any x lies outside ``self.bounds``.
    """
    # Check that all x are within the bounds.  `all` short-circuits and
    # is a genuine boolean test, unlike the previous `np.prod` trick.
    if not all(self.bounds[0] <= x <= self.bounds[1] for x in xs):
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    def _flush(x, samples):
        # Hand one buffered x and its samples to the learner.
        if len(samples) == 1:
            self.tell(x, samples[0])
        elif len(samples) > 1:
            # Telling all samples of one x at once is cheaper than
            # repeated `tell` calls: the data structures and losses are
            # updated only once for the whole batch.
            self.tell_many_samples(x, samples)

    x_old = np.inf  # sentinel: no x buffered yet (inf is out of bounds)
    ys_old = []
    for x, y in zip(xs, ys):
        if x == x_old:
            # Same x as the previous sample: keep buffering.
            ys_old.append(y)
        else:
            # A new x appeared: flush the buffered run, start a new one.
            _flush(x_old, ys_old)
            x_old = x
            ys_old = [y]
    # Flush the trailing run (no-op when xs was empty).
    _flush(x_old, ys_old)
353+
def tell_many_samples(self, x, ys):
    """Tell the learner about many samples at a certain location x.

    Updating the internal data structures once for a whole batch of
    samples at the same x is cheaper than calling `tell` per sample.

    Parameters
    ----------
    x : Value from the function domain
    ys : List of data samples at x

    Raises
    ------
    ValueError
        If x lies outside ``self.bounds``.
    """
    # Check x is within the bounds (plain chained comparison; the
    # previous `np.prod(...)` of a scalar boolean was meaningless).
    if not self.bounds[0] <= x <= self.bounds[1]:
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    # Mean over *all* samples passed in this call (computed before any
    # element is popped below).
    y_avg = np.mean(ys)
    # If x is a new point, seed the data structures with one sample.
    if x not in self.data:
        # NOTE(review): this mutates the caller's list `ys` in place —
        # confirm callers never reuse the list afterwards.
        y = ys.pop(0)
        self._update_data(x, y, "new")
        self._update_data_structures(x, y, "new")
    # If x is not a new point, or if there was more than 1 sample in ys:
    if len(ys):
        self.data[x] = y_avg
        self._data_samples.update({x: ys + self._data_samples[x]})
        n = len(self._data_samples[x])
        self._number_samples[x] = n
        # self._update_data(x, y, "new") included the point in
        # _undersampled_points.  We remove it if there are more than
        # min_samples samples, disregarding neighbor_sampling.
        if n > self.min_samples:
            self._undersampled_points.discard(x)
        # NOTE(review): y_avg is the mean of the samples from *this*
        # call only, while the error is computed over the combined
        # sample list — if x already had samples the two can disagree;
        # confirm this is intended.
        self._error_in_mean[x] = self._calc_error_in_mean(
            self._data_samples[x], y_avg, n
        )
        self._update_distances(x)
        self._update_rescaled_error_in_mean(x, "resampled")
        # A point sampled precisely enough (or at the sample cap) no
        # longer competes for resampling.
        if self._error_in_mean[x] <= self.min_error or n >= self.max_samples:
            self._rescaled_error_in_mean.pop(x, None)
        super()._update_scale(x, y_avg)
        self._update_losses_resampling(x, real=True)
        # If the scale grew a lot, every interpolated loss is stale:
        # recompute them all and remember the scale we used.
        if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
            for interval in reversed(self.losses):
                self._update_interpolated_loss_in_interval(*interval)
            self._oldscale = deepcopy(self._scale)
403396

404397
def plot(self):
405398
"""Returns a plot of the evaluated data with error bars (not implemented

0 commit comments

Comments
 (0)