Skip to content

Commit 953ff84

Browse files
committed
Merge branch 'cache_loss' into 'master'
Cache loss and display it in the live_info widget See merge request qt/adaptive!117
2 parents b5b81ac + d5774ff commit 953ff84

11 files changed

Lines changed: 38 additions & 18 deletions

adaptive/learner/average_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import numpy as np
66

7-
from ..notebook_integration import ensure_holoviews
87
from .base_learner import BaseLearner
8+
from ..notebook_integration import ensure_holoviews
9+
from ..utils import cache_latest
910

1011

1112
class AverageLearner(BaseLearner):
@@ -90,6 +91,7 @@ def std(self):
9091
return np.inf
9192
return sqrt((self.sum_f_sq - n * self.mean**2) / (n - 1))
9293

94+
@cache_latest
9395
def loss(self, real=True, *, n=None):
9496
if n is None:
9597
n = self.npoints if real else self.n_requested

adaptive/learner/balancing_learner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .base_learner import BaseLearner
88
from ..notebook_integration import ensure_holoviews
9-
from ..utils import restore, named_product
9+
from ..utils import cache_latest, named_product, restore
1010

1111

1212
def dispatch(child_functions, arg):
@@ -116,6 +116,7 @@ def losses(self, real=True):
116116

117117
return losses
118118

119+
@cache_latest
119120
def loss(self, real=True):
120121
losses = self.losses(real)
121122
return max(losses)

adaptive/learner/data_saver.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# -*- coding: utf-8 -*-
2-
32
from collections import OrderedDict
43
import functools
54

adaptive/learner/integrator_learner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
ndiv_max, min_sep, eps, xi, V_inv,
1616
Vcond, alpha, gamma)
1717
from ..notebook_integration import ensure_holoviews
18-
from ..utils import restore
18+
from ..utils import cache_latest, restore
1919

2020

2121
def _downdate(c, nans, depth):
@@ -514,6 +514,7 @@ def done(self):
514514
or (err - err_excess < abs(igral) * self.tol < err_excess)
515515
or not self.ivals)
516516

517+
@cache_latest
517518
def loss(self, real=True):
518519
return abs(abs(self.igral) * self.tol - self.err)
519520

adaptive/learner/learner1D.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import numpy as np
88
import sortedcontainers
99

10-
from ..notebook_integration import ensure_holoviews
1110
from .base_learner import BaseLearner
11+
from ..notebook_integration import ensure_holoviews
12+
from ..utils import cache_latest
1213

1314

1415
def uniform_loss(interval, scale, function_values):
@@ -156,6 +157,7 @@ def vdim(self):
156157
def npoints(self):
157158
return len(self.data)
158159

160+
@cache_latest
159161
def loss(self, real=True):
160162
losses = self.losses if real else self.losses_combined
161163
return max(losses.values()) if len(losses) > 0 else float('inf')

adaptive/learner/learner2D.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import numpy as np
88
from scipy import interpolate
99

10-
from ..notebook_integration import ensure_holoviews
1110
from .base_learner import BaseLearner
11+
from ..notebook_integration import ensure_holoviews
12+
from ..utils import cache_latest
1213

1314

1415
# Learner2D and helper functions.
@@ -267,7 +268,6 @@ def __init__(self, function, bounds, loss_per_triangle=None):
267268
self._stack.update({p: np.inf for p in self._bounds_points})
268269
self.function = function
269270
self._ip = self._ip_combined = None
270-
self._loss = np.inf
271271

272272
self.stack_size = 10
273273

@@ -438,13 +438,13 @@ def ask(self, n, tell_pending=True):
438438

439439
return points[:n], loss_improvements[:n]
440440

441+
@cache_latest
441442
def loss(self, real=True):
442443
if not self.bounds_are_done:
443444
return np.inf
444445
ip = self.ip() if real else self.ip_combined()
445446
losses = self.loss_per_triangle(ip)
446-
self._loss = losses.max()
447-
return self._loss
447+
return losses.max()
448448

449449
def remove_unfinished(self):
450450
self.pending_points = set()

adaptive/learner/learnerND.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..notebook_integration import ensure_holoviews
1414
from .triangulation import (Triangulation, point_in_simplex,
1515
circumsphere, simplex_volume_in_embedding)
16-
from ..utils import restore
16+
from ..utils import restore, cache_latest
1717

1818

1919
def volume(simplex, ys=None):
@@ -452,6 +452,7 @@ def losses(self):
452452

453453
return self._losses
454454

455+
@cache_latest
455456
def loss(self, real=True):
456457
losses = self.losses() # XXX: compute pending loss if real == False
457458
return max(losses.values()) if losses else float('inf')

adaptive/learner/skopt_learner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# -*- coding: utf-8 -*-
2-
32
import numpy as np
3+
from skopt import Optimizer
44

5-
from ..notebook_integration import ensure_holoviews
65
from .base_learner import BaseLearner
7-
8-
from skopt import Optimizer
6+
from ..notebook_integration import ensure_holoviews
7+
from ..utils import restore, cache_latest
98

109

1110
class SKOptLearner(Optimizer, BaseLearner):
@@ -38,6 +37,7 @@ def tell_pending(self, x):
3837
def remove_unfinished(self):
3938
pass
4039

40+
@cache_latest
4141
def loss(self, real=True):
4242
if not self.models:
4343
return np.inf

adaptive/notebook_integration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ def _info_html(runner):
174174
with suppress(Exception):
175175
info.append(('# of points', runner.learner.npoints))
176176

177+
with suppress(Exception):
178+
info.append(('latest loss', f'{runner.learner._cache["loss"]:.3f}'))
179+
177180
template = '<dt>{}</dt><dd>{}</dd>'
178181
table = '\n'.join(template.format(k, v) for k, v in info)
179182

adaptive/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
from contextlib import contextmanager
3+
from functools import wraps
34
from itertools import product
45
import time
56

@@ -24,3 +25,16 @@ def restore(*learners):
2425
finally:
2526
for state, learner in zip(states, learners):
2627
learner.__setstate__(state)
28+
29+
30+
def cache_latest(f):
31+
"""Cache the latest return value of the function and add it
32+
as 'self._cache[f.__name__]'."""
33+
@wraps(f)
34+
def wrapper(*args, **kwargs):
35+
self = args[0]
36+
if not hasattr(self, '_cache'):
37+
self._cache = {}
38+
self._cache[f.__name__] = f(*args, **kwargs)
39+
return self._cache[f.__name__]
40+
return wrapper

0 commit comments

Comments
 (0)