@@ -36,10 +36,10 @@ class BalancingLearner(BaseLearner):
3636 >>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
3737 >>> cdims = (['A', 'B'], [(True, 0), (True, 1),
3838 ... (False, 0), (False, 1)])
39- strategy : 'loss_improvements' (default) or 'loss '
40- The points that the 'BalancingLearner' choses can be either based on
41- the best 'loss_improvements' or the smallest total 'loss' of the
42- child learners.
39+ strategy : 'loss_improvements' (default), 'loss', or 'npoints '
40+ The points that the 'BalancingLearner' choses can be either based on:
41+ the best 'loss_improvements', the smallest total 'loss' of the
42+ child learners, or the number of points per learner, using 'npoints' .
4343
4444 Notes
4545 -----
@@ -73,9 +73,11 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
7373 self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
7474 elif strategy == 'loss' :
7575 self ._ask_and_tell = self ._ask_and_tell_based_on_loss
76+ elif strategy == 'npoints' :
77+ self ._ask_and_tell = self ._ask_and_tell_based_on_npoints
7678 else :
77- raise ValueError ('Only strategy="loss_improvements" or '
78- ' strategy="loss" is implemented.' )
79+ raise ValueError ('Only strategy="loss_improvements", '
80+ ' strategy="loss", or strategy="npoints " is implemented.' )
7981
8082 def _ask_and_tell_based_on_loss_improvements (self , n ):
8183 points = []
@@ -109,6 +111,21 @@ def _ask_and_tell_based_on_loss(self, n):
109111 loss_improvements .append (ls [0 ])
110112 return points , loss_improvements
111113
114+ def _ask_and_tell_based_on_npoints (self , n ):
115+ points = []
116+ loss_improvements = []
117+ npoints = [l .npoints + len (l .pending_points )
118+ for l in self .learners ]
119+ n_left = n
120+ while n_left > 0 :
121+ i = np .argmin (npoints )
122+ xs , ls = self .learners [i ].ask (1 )
123+ npoints [i ] += 1
124+ n_left -= 1
125+ points .append ((i , xs [0 ]))
126+ loss_improvements .append (ls [0 ])
127+ return points , loss_improvements
128+
112129 def ask (self , n , tell_pending = True ):
113130 """Chose points for learners."""
114131 if not tell_pending :
0 commit comments