Skip to content

Commit 3696d4c

Browse files
jhoofwijkbasnijholt
authored andcommitted
make a BalancingLearner strategy that compares the total loss rather than loss improvement
1 parent 953ff84 commit 3696d4c

1 file changed

Lines changed: 28 additions & 3 deletions

File tree

adaptive/learner/balancing_learner.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from functools import partial
55
from operator import itemgetter
66

7+
import numpy as np
8+
79
from .base_learner import BaseLearner
810
from ..notebook_integration import ensure_holoviews
911
from ..utils import cache_latest, named_product, restore
@@ -34,6 +36,10 @@ class BalancingLearner(BaseLearner):
3436
>>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
3537
>>> cdims = (['A', 'B'], [(True, 0), (True, 1),
3638
... (False, 0), (False, 1)])
39+
strategy : 'loss_improvements' (default) or 'loss'
40+
The points that the 'BalancingLearner' choses can be either based on
41+
the best 'loss_improvements' or the smallest total 'loss' of the
42+
child learners.
3743
3844
Notes
3945
-----
@@ -46,7 +52,7 @@ class BalancingLearner(BaseLearner):
4652
undefined way.
4753
"""
4854

49-
def __init__(self, learners, *, cdims=None):
55+
def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
5056
self.learners = learners
5157

5258
# Naively we would make 'function' a method, but this causes problems
@@ -61,9 +67,17 @@ def __init__(self, learners, *, cdims=None):
6167

6268
if len(set(learner.__class__ for learner in self.learners)) > 1:
6369
raise TypeError('A BalacingLearner can handle only one type'
64-
'of learners.')
70+
' of learners.')
71+
72+
if strategy == 'loss_improvements':
73+
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
74+
elif strategy == 'loss':
75+
self._ask_and_tell = self._ask_and_tell_based_on_loss
76+
else:
77+
raise ValueError('Only strategy="loss_improvements" or'
78+
' strategy="loss" is implemented.')
6579

66-
def _ask_and_tell(self, n):
80+
def _ask_and_tell_based_on_loss_improvements(self, n):
6781
points = []
6882
loss_improvements = []
6983
for _ in range(n):
@@ -84,6 +98,17 @@ def _ask_and_tell(self, n):
8498

8599
return points, loss_improvements
86100

101+
def _ask_and_tell_based_on_loss(self, n):
102+
points = []
103+
loss_improvements = []
104+
for _ in range(n):
105+
losses = self.losses(real=False)
106+
max_ind = np.argmax(losses)
107+
xs, ls = self.learners[max_ind].ask(1)
108+
points.append((max_ind, xs[0]))
109+
loss_improvements.append(ls[0])
110+
return points, loss_improvements
111+
87112
def ask(self, n, tell_pending=True):
88113
"""Chose points for learners."""
89114
if not tell_pending:

0 commit comments

Comments
 (0)