Skip to content

Commit 7c2b0dd

Browse files
committed
run 'pre-commit run --all'
1 parent 16c13a4 commit 7c2b0dd

1 file changed

Lines changed: 20 additions & 18 deletions

File tree

adaptive/learner/average_learner1D.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,16 @@ def ask(self, n, tell_pending=True):
128128

129129
def _ask_for_more_samples(self, x, n):
130130
"""When asking for n points, the learner returns n times an existing point
131-
to be resampled, since in general n << min_samples and this point will
132-
need to be resampled many more times"""
131+
to be resampled, since in general n << min_samples and this point will
132+
need to be resampled many more times"""
133133
points = [x] * n
134134
loss_improvements = [0] * n # We set the loss_improvements of resamples to 0
135135
return points, loss_improvements
136136

137137
def _ask_for_new_point(self, n):
138138
"""When asking for n new points, the learner returns n times a single
139-
new point, since in general n << min_samples and this point will need
140-
to be resampled many more times"""
139+
new point, since in general n << min_samples and this point will need
140+
to be resampled many more times"""
141141
points, loss_improvements = self._ask_points_without_adding(1)
142142
points = points * n
143143
loss_improvements = loss_improvements + [0] * (n - 1)
@@ -171,7 +171,7 @@ def tell(self, x, y):
171171

172172
def _update_rescaled_error_in_mean(self, x, point_type):
173173
"""Updates self._rescaled_error_in_mean; point_type must be "new" or
174-
"resampled". """
174+
"resampled"."""
175175
#  Update neighbors
176176
x_left, x_right = self.neighbors[x]
177177
dists = self._distances
@@ -324,31 +324,31 @@ def _calc_error_in_mean(self, ys, y_avg, n):
324324

325325
def tell_many(self, xs, ys):
326326
# Check that all x are within the bounds
327-
if not np.prod([x>=self.bounds[0] and x<=self.bounds[1] for x in xs]):
327+
if not np.prod([x >= self.bounds[0] and x <= self.bounds[1] for x in xs]):
328328
raise ValueError(
329329
"x value out of bounds, "
330330
"remove x or enlarge the bounds of the learner"
331331
)
332332
x_old = np.inf
333333
ys_old = []
334-
for x, y in zip(xs,ys):
334+
for x, y in zip(xs, ys):
335335
if x == x_old:
336336
# Store the y-values until a new x is found in xs
337337
ys_old.append(y)
338338
else:
339-
if len(ys_old)==1:
340-
self.tell(x_old,ys_old[0])
341-
elif len(ys_old)>1:
339+
if len(ys_old) == 1:
340+
self.tell(x_old, ys_old[0])
341+
elif len(ys_old) > 1:
342342
# If we stored more than 1 y-value for the previous x,
343343
# use a more efficient routine to tell many samples
344344
# simultaneously, before we move on to a new x
345-
self.tell_many_samples(x_old,ys_old)
345+
self.tell_many_samples(x_old, ys_old)
346346
x_old = x
347347
ys_old = [y]
348-
if len(ys_old)==1:
349-
self.tell(x_old,ys_old[0])
350-
elif len(ys_old)>1:
351-
self.tell_many_samples(x_old,ys_old)
348+
if len(ys_old) == 1:
349+
self.tell(x_old, ys_old[0])
350+
elif len(ys_old) > 1:
351+
self.tell_many_samples(x_old, ys_old)
352352

353353
def tell_many_samples(self, x, ys):
354354
"""Tell the learner about many samples at a certain location x.
@@ -359,7 +359,7 @@ def tell_many_samples(self, x, ys):
359359
ys : List of data samples at x
360360
"""
361361
# Check x is within the bounds
362-
if not np.prod(x>=self.bounds[0] and x<=self.bounds[1]):
362+
if not np.prod(x >= self.bounds[0] and x <= self.bounds[1]):
363363
raise ValueError(
364364
"x value out of bounds, "
365365
"remove x or enlarge the bounds of the learner"
@@ -374,15 +374,17 @@ def tell_many_samples(self, x, ys):
374374
# If x is not a new point or if there were more than 1 sample in ys:
375375
if len(ys):
376376
self.data[x] = y_avg
377-
self._data_samples.update({x: ys+self._data_samples[x]})
377+
self._data_samples.update({x: ys + self._data_samples[x]})
378378
n = len(self._data_samples[x])
379379
self._number_samples[x] = n
380380
# self._update_data(x,y,"new") included the point
381381
# in _undersampled_points. We remove it if there are
382382
# more than min_samples samples, disregarding neighbor_sampling.
383383
if n > self.min_samples:
384384
self._undersampled_points.discard(x)
385-
self._error_in_mean[x] = self._calc_error_in_mean(self._data_samples[x], y_avg, n)
385+
self._error_in_mean[x] = self._calc_error_in_mean(
386+
self._data_samples[x], y_avg, n
387+
)
386388
self._update_distances(x)
387389
self._update_rescaled_error_in_mean(x, "resampled")
388390
if self._error_in_mean[x] <= self.min_error or n >= self.max_samples:

0 commit comments

Comments
 (0)