Skip to content

Commit aed824a

Browse files
committed
simplify tell_many
1 parent 4434ccb commit aed824a

1 file changed

Lines changed: 14 additions & 19 deletions

File tree

adaptive/learner/average_learner1D.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
from copy import deepcopy
23
from math import hypot
34

@@ -336,26 +337,20 @@ def tell_many(self, xs, ys):
336337
"x value out of bounds, "
337338
"remove x or enlarge the bounds of the learner"
338339
)
339-
x_old = np.inf
340-
ys_old = []
340+
341+
# Create a mapping of points to a list of samples
342+
mapping = defaultdict(list)
341343
for x, y in zip(xs, ys):
342-
if x == x_old:
343-
# Store the y-values until a new x is found in xs
344-
ys_old.append(y)
345-
else:
346-
if len(ys_old) == 1:
347-
self.tell(x_old, ys_old[0])
348-
elif len(ys_old) > 1:
349-
# If we stored more than 1 y-value for the previous x,
350-
# use a more efficient routine to tell many samples
351-
# simultaneously, before we move on to a new x
352-
self.tell_many_at_point(x_old, ys_old)
353-
x_old = x
354-
ys_old = [y]
355-
if len(ys_old) == 1:
356-
self.tell(x_old, ys_old[0])
357-
elif len(ys_old) > 1:
358-
self.tell_many_at_point(x_old, ys_old)
344+
mapping[x].append(y)
345+
346+
for x, ys in mapping.items():
347+
if len(ys) == 1:
348+
self.tell(x, ys[0])
349+
elif len(ys) > 1:
350+
# If we stored more than 1 y-value for the previous x,
351+
# use a more efficient routine to tell many samples
352+
# simultaneously, before we move on to a new x
353+
self.tell_many_at_point(x, ys)
359354

360355
def tell_many_at_point(self, x, ys):
361356
"""Tell the learner about many samples at a certain location x.

0 commit comments

Comments
 (0)