|
| 1 | +from collections import defaultdict |
1 | 2 | from copy import deepcopy |
2 | 3 | from math import hypot |
3 | 4 |
|
@@ -336,26 +337,20 @@ def tell_many(self, xs, ys): |
336 | 337 | "x value out of bounds, " |
337 | 338 | "remove x or enlarge the bounds of the learner" |
338 | 339 | ) |
339 | | - x_old = np.inf |
340 | | - ys_old = [] |
| 340 | + |
| 341 | + # Create a mapping of points to a list of samples |
| 342 | + mapping = defaultdict(list) |
341 | 343 | for x, y in zip(xs, ys): |
342 | | - if x == x_old: |
343 | | - # Store the y-values until a new x is found in xs |
344 | | - ys_old.append(y) |
345 | | - else: |
346 | | - if len(ys_old) == 1: |
347 | | - self.tell(x_old, ys_old[0]) |
348 | | - elif len(ys_old) > 1: |
349 | | - # If we stored more than 1 y-value for the previous x, |
350 | | - # use a more efficient routine to tell many samples |
351 | | - # simultaneously, before we move on to a new x |
352 | | - self.tell_many_at_point(x_old, ys_old) |
353 | | - x_old = x |
354 | | - ys_old = [y] |
355 | | - if len(ys_old) == 1: |
356 | | - self.tell(x_old, ys_old[0]) |
357 | | - elif len(ys_old) > 1: |
358 | | - self.tell_many_at_point(x_old, ys_old) |
| 344 | + mapping[x].append(y) |
| 345 | + |
| 346 | + for x, ys in mapping.items(): |
| 347 | + if len(ys) == 1: |
| 348 | + self.tell(x, ys[0]) |
| 349 | + elif len(ys) > 1: |
| 350 | + # If we stored more than 1 y-value for the previous x, |
| 351 | + # use a more efficient routine to tell many samples |
| 352 | + # simultaneously, before we move on to a new x |
| 353 | + self.tell_many_at_point(x, ys) |
359 | 354 |
|
360 | 355 | def tell_many_at_point(self, x, ys): |
361 | 356 | """Tell the learner about many samples at a certain location x. |
|
0 commit comments