Skip to content

Commit 9aa93ec

Browse files
committed
rename _error_in_mean -> error and _rescaled_error_in_mean -> rescaled_error
1 parent 9c55259 commit 9aa93ec

1 file changed

Lines changed: 35 additions & 38 deletions

File tree

adaptive/learner/average_learner1D.py

Lines changed: 35 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class AverageLearner1D(Learner1D):
2222
We strongly recommend 0 < delta <= 1.
2323
alpha : float (0 < alpha < 1)
2424
The true value of the function at x is within the confidence interval
25-
[self.data[x] - self._error_in_mean[x], self.data[x] +
26-
self._error_in_mean[x]] with probability 1-2*alpha.
25+
[self.data[x] - self.error[x], self.data[x] +
26+
self.error[x]] with probability 1-2*alpha.
2727
We recommend to keep alpha=0.005.
2828
neighbor_sampling : float (0 < neighbor_sampling <= 1)
2929
Each new point is initially sampled at least a (neighbor_sampling*100)%
@@ -36,9 +36,9 @@ class AverageLearner1D(Learner1D):
3636
min_error : float (min_error >= 0)
3737
Minimum size of the confidence intervals. The true value of the
3838
function at x is within the confidence interval [self.data[x] -
39-
self._error_in_mean[x], self.data[x] + self._error_in_mean[x]] with
39+
self.error[x], self.data[x] + self.error[x]] with
4040
probability 1-2*alpha.
41-
If self._error_in_mean[x] < min_error, then x will not be resampled
41+
If self.error[x] < min_error, then x will not be resampled
4242
anymore, i.e., the smallest confidence interval at x is
4343
[self.data[x] - min_error, self.data[x] + min_error].
4444
"""
@@ -86,12 +86,12 @@ def __init__(
8686
self._undersampled_points = set()
8787
# Contains the error in the estimate of the
8888
# mean at each point x in the form {x0: error(x0), ...}
89-
self._error_in_mean = decreasing_dict()
89+
self.error = decreasing_dict()
9090
#  Distance between two neighboring points in the
9191
# form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
9292
self._distances = decreasing_dict()
93-
# {xii: _error_in_mean[xii]/min(_distances[xi], _distances[xii], ...}
94-
self._rescaled_error_in_mean = decreasing_dict()
93+
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
94+
self.rescaled_error = decreasing_dict()
9595

9696
@property
9797
def total_samples(self):
@@ -109,9 +109,9 @@ def ask(self, n, tell_pending=True):
109109
points, loss_improvements = self._ask_for_new_point(n)
110110
#  Else, check the resampling condition
111111
else:
112-
if len(self._rescaled_error_in_mean):
113-
# This is in case _rescaled_error_in_mean is empty (e.g. when sigma=0)
114-
x, resc_error = self._rescaled_error_in_mean.peekitem(0)
112+
if len(self.rescaled_error):
113+
# This is in case rescaled_error is empty (e.g. when sigma=0)
114+
x, resc_error = self.rescaled_error.peekitem(0)
115115
# Resampling condition
116116
if resc_error > self.delta:
117117
points, loss_improvements = self._ask_for_more_samples(x, n)
@@ -169,9 +169,14 @@ def tell(self, x, y):
169169
self._update_data_structures(x, y, "resampled")
170170
self.pending_points.discard(x)
171171

172-
def _update_rescaled_error_in_mean(self, x, point_type):
173-
"""Updates self._rescaled_error_in_mean; point_type must be "new" or
174-
"resampled"."""
172+
def _update_rescaled_error_in_mean(self, x, point_type: str) -> None:
173+
"""Updates self.rescaled_error.
174+
175+
Parameters
176+
----------
177+
point_type : str
178+
Must be either "new" or "resampled".
179+
"""
175180
#  Update neighbors
176181
x_left, x_right = self.neighbors[x]
177182
dists = self._distances
@@ -182,38 +187,34 @@ def _update_rescaled_error_in_mean(self, x, point_type):
182187
d_left = dists[x]
183188
else:
184189
d_left = dists[x_left]
185-
if x_left in self._rescaled_error_in_mean:
190+
if x_left in self.rescaled_error:
186191
xll = self.neighbors[x_left][0]
187192
norm = dists[x_left] if xll is None else min(dists[xll], dists[x_left])
188-
self._rescaled_error_in_mean[x_left] = (
189-
self._error_in_mean[x_left] / norm
190-
)
193+
self.rescaled_error[x_left] = self.error[x_left] / norm
191194

192195
if x_right is None:
193196
d_right = dists[x_left]
194197
else:
195198
d_right = dists[x]
196-
if x_right in self._rescaled_error_in_mean:
199+
if x_right in self.rescaled_error:
197200
xrr = self.neighbors[x_right][1]
198201
norm = dists[x] if xrr is None else min(dists[x], dists[x_right])
199-
self._rescaled_error_in_mean[x_right] = (
200-
self._error_in_mean[x_right] / norm
201-
)
202+
self.rescaled_error[x_right] = self.error[x_right] / norm
202203

203204
# Update x
204205
if point_type == "resampled":
205206
norm = min(d_left, d_right)
206-
self._rescaled_error_in_mean[x] = self._error_in_mean[x] / norm
207+
self.rescaled_error[x] = self.error[x] / norm
207208

208-
def _update_data(self, x, y, point_type):
209+
def _update_data(self, x, y, point_type: str):
209210
if point_type == "new":
210211
self.data[x] = y
211212
elif point_type == "resampled":
212213
n = len(self._data_samples[x])
213214
new_average = self.data[x] * n / (n + 1) + y / (n + 1)
214215
self.data[x] = new_average
215216

216-
def _update_data_structures(self, x, y, point_type):
217+
def _update_data_structures(self, x, y, point_type: str):
217218
if point_type == "new":
218219
self._data_samples[x] = [y]
219220

@@ -233,8 +234,8 @@ def _update_data_structures(self, x, y, point_type):
233234

234235
self._number_samples[x] = 1
235236
self._undersampled_points.add(x)
236-
self._error_in_mean[x] = np.inf
237-
self._rescaled_error_in_mean[x] = np.inf
237+
self.error[x] = np.inf
238+
self.rescaled_error[x] = np.inf
238239
self._update_distances(x)
239240
self._update_rescaled_error_in_mean(x, "new")
240241

@@ -261,12 +262,12 @@ def _update_data_structures(self, x, y, point_type):
261262
# the mean value lies within the correct interval of confidence
262263
y_avg = self.data[x]
263264
ys = self._data_samples[x]
264-
self._error_in_mean[x] = self._calc_error_in_mean(ys, y_avg, n)
265+
self.error[x] = self._calc_error_in_mean(ys, y_avg, n)
265266
self._update_distances(x)
266267
self._update_rescaled_error_in_mean(x, "resampled")
267268

268-
if self._error_in_mean[x] <= self.min_error or n >= self.max_samples:
269-
self._rescaled_error_in_mean.pop(x, None)
269+
if self.error[x] <= self.min_error or n >= self.max_samples:
270+
self.rescaled_error.pop(x, None)
270271

271272
# We also need to update scale and losses
272273
super()._update_scale(x, y)
@@ -382,13 +383,11 @@ def tell_many_samples(self, x, ys):
382383
# more than min_samples samples, disregarding neighbor_sampling.
383384
if n > self.min_samples:
384385
self._undersampled_points.discard(x)
385-
self._error_in_mean[x] = self._calc_error_in_mean(
386-
self._data_samples[x], y_avg, n
387-
)
386+
self.error[x] = self._calc_error_in_mean(self._data_samples[x], y_avg, n)
388387
self._update_distances(x)
389388
self._update_rescaled_error_in_mean(x, "resampled")
390-
if self._error_in_mean[x] <= self.min_error or n >= self.max_samples:
391-
self._rescaled_error_in_mean.pop(x, None)
389+
if self.error[x] <= self.min_error or n >= self.max_samples:
390+
self.rescaled_error.pop(x, None)
392391
super()._update_scale(x, y_avg)
393392
self._update_losses_resampling(x, real=True)
394393
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
@@ -402,7 +401,7 @@ def plot(self):
402401
403402
Returns
404403
-------
405-
plot : `holoviews.element.Scatter * holoviews.element.ErroBars *
404+
plot : `holoviews.element.Scatter * holoviews.element.ErrorBars *
406405
holoviews.element.Path`
407406
Plot of the evaluated data.
408407
"""
@@ -412,9 +411,7 @@ def plot(self):
412411
elif not self.vdim > 1:
413412
xs, ys = zip(*sorted(self.data.items()))
414413
scatter = hv.Scatter(self.data)
415-
error = hv.ErrorBars(
416-
[(x, self.data[x], self._error_in_mean[x]) for x in self.data]
417-
)
414+
error = hv.ErrorBars([(x, self.data[x], self.error[x]) for x in self.data])
418415
line = hv.Path((xs, ys))
419416
p = scatter * error * line
420417
else:

0 commit comments

Comments
 (0)