Skip to content

Commit 722f61a

Browse files
committed
implement 'npoints' strategy for the 'BalancingLearner'
1 parent 2208d1a commit 722f61a

1 file changed

Lines changed: 23 additions & 6 deletions

File tree

adaptive/learner/balancing_learner.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class BalancingLearner(BaseLearner):
3636
>>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
3737
>>> cdims = (['A', 'B'], [(True, 0), (True, 1),
3838
... (False, 0), (False, 1)])
39-
strategy : 'loss_improvements' (default) or 'loss'
40-
The points that the 'BalancingLearner' choses can be either based on
41-
the best 'loss_improvements' or the smallest total 'loss' of the
42-
child learners.
39+
strategy : 'loss_improvements' (default), 'loss', or 'npoints'
40+
The points that the 'BalancingLearner' choses can be either based on:
41+
the best 'loss_improvements', the smallest total 'loss' of the
42+
child learners, or the number of points per learner, using 'npoints'.
4343
4444
Notes
4545
-----
@@ -73,9 +73,11 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
7373
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
7474
elif strategy == 'loss':
7575
self._ask_and_tell = self._ask_and_tell_based_on_loss
76+
elif strategy == 'npoints':
77+
self._ask_and_tell = self._ask_and_tell_based_on_npoints
7678
else:
77-
raise ValueError('Only strategy="loss_improvements" or'
78-
' strategy="loss" is implemented.')
79+
raise ValueError('Only strategy="loss_improvements",'
80+
' strategy="loss", or strategy="npoints" is implemented.')
7981

8082
def _ask_and_tell_based_on_loss_improvements(self, n):
8183
points = []
@@ -109,6 +111,21 @@ def _ask_and_tell_based_on_loss(self, n):
109111
loss_improvements.append(ls[0])
110112
return points, loss_improvements
111113

114+
def _ask_and_tell_based_on_npoints(self, n):
115+
points = []
116+
loss_improvements = []
117+
npoints = [l.npoints + len(l.pending_points)
118+
for l in self.learners]
119+
n_left = n
120+
while n_left > 0:
121+
i = np.argmin(npoints)
122+
xs, ls = self.learners[i].ask(1)
123+
npoints[i] += 1
124+
n_left -= 1
125+
points.append((i, xs[0]))
126+
loss_improvements.append(ls[0])
127+
return points, loss_improvements
128+
112129
def ask(self, n, tell_pending=True):
113130
"""Chose points for learners."""
114131
if not tell_pending:

0 commit comments

Comments
 (0)