44from functools import partial
55from operator import itemgetter
66
7+ import numpy as np
8+
79from .base_learner import BaseLearner
810from ..notebook_integration import ensure_holoviews
911from ..utils import cache_latest , named_product , restore
@@ -34,6 +36,10 @@ class BalancingLearner(BaseLearner):
3436 >>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
3537 >>> cdims = (['A', 'B'], [(True, 0), (True, 1),
3638 ... (False, 0), (False, 1)])
39+ strategy : 'loss_improvements' (default) or 'loss'
40+ The points that the 'BalancingLearner' choses can be either based on
41+ the best 'loss_improvements' or the smallest total 'loss' of the
42+ child learners.
3743
3844 Notes
3945 -----
@@ -46,7 +52,7 @@ class BalancingLearner(BaseLearner):
4652 undefined way.
4753 """
4854
49- def __init__ (self , learners , * , cdims = None ):
55+ def __init__ (self , learners , * , cdims = None , strategy = 'loss_improvements' ):
5056 self .learners = learners
5157
5258 # Naively we would make 'function' a method, but this causes problems
@@ -61,9 +67,17 @@ def __init__(self, learners, *, cdims=None):
6167
6268 if len (set (learner .__class__ for learner in self .learners )) > 1 :
6369 raise TypeError ('A BalacingLearner can handle only one type'
64- 'of learners.' )
70+ ' of learners.' )
71+
72+ if strategy == 'loss_improvements' :
73+ self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
74+ elif strategy == 'loss' :
75+ self ._ask_and_tell = self ._ask_and_tell_based_on_loss
76+ else :
77+ raise ValueError ('Only strategy="loss_improvements" or'
78+ ' strategy="loss" is implemented.' )
6579
66- def _ask_and_tell (self , n ):
80+ def _ask_and_tell_based_on_loss_improvements (self , n ):
6781 points = []
6882 loss_improvements = []
6983 for _ in range (n ):
@@ -84,6 +98,17 @@ def _ask_and_tell(self, n):
8498
8599 return points , loss_improvements
86100
101+ def _ask_and_tell_based_on_loss (self , n ):
102+ points = []
103+ loss_improvements = []
104+ for _ in range (n ):
105+ losses = self .losses (real = False )
106+ max_ind = np .argmax (losses )
107+ xs , ls = self .learners [max_ind ].ask (1 )
108+ points .append ((max_ind , xs [0 ]))
109+ loss_improvements .append (ls [0 ])
110+ return points , loss_improvements
111+
87112 def ask (self , n , tell_pending = True ):
88113 """Chose points for learners."""
89114 if not tell_pending :
0 commit comments