33import heapq
44import itertools
55import math
6+ from collections import Iterable
67
78import numpy as np
89import sortedcontainers
910
1011from .base_learner import BaseLearner
12+ from .learnerND import volume
13+ from .triangulation import simplex_volume_in_embedding
1114from ..notebook_integration import ensure_holoviews
1215from ..utils import cache_latest
1316
@@ -56,6 +59,45 @@ def default_loss(interval, scale, function_values):
5659 return loss
5760
5861
62+ def _loss_of_multi_interval (xs , ys ):
63+ N = len (xs ) - 2
64+ if isinstance (ys [0 ], Iterable ):
65+ pts = [(x , * y ) for x , y in zip (xs , ys )]
66+ vol = simplex_volume_in_embedding
67+ else :
68+ pts = [(x , y ) for x , y in zip (xs , ys )]
69+ vol = volume
70+ return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
71+
72+
73+ def triangle_loss (interval , neighbours , scale , function_values ):
74+ x_left , x_right = interval
75+ neighbour_left , neighbour_right = neighbours
76+ xs = [neighbour_left , x_left , x_right , neighbour_right ]
77+ # The neighbours could be None if we are at the boundary, in that case we
78+ # have to filter this out
79+ xs = [x for x in xs if x is not None ]
80+
81+ if len (xs ) <= 2 :
82+ return (x_right - x_left ) / scale [0 ]
83+ else :
84+ y_scale = scale [1 ] or 1
85+ ys_scaled = [function_values [x ] / y_scale for x in xs ]
86+ xs_scaled = [x / scale [0 ] for x in xs ]
87+ return _loss_of_multi_interval (xs_scaled , ys_scaled )
88+
89+
90+ def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
91+ def curvature_loss (interval , neighbours , scale , function_values ):
92+ triangle_loss_ = triangle_loss (interval , neighbours , scale , function_values )
93+ default_loss_ = default_loss (interval , scale , function_values )
94+ dx = (interval [1 ] - interval [0 ]) / scale [0 ]
95+ return (area_factor * (triangle_loss_ ** 0.5 )
96+ + euclid_factor * default_loss_
97+ + horizontal_factor * dx )
98+ return curvature_loss
99+
100+
59101def linspace (x_left , x_right , n ):
60102 """This is equivalent to
61103 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -116,9 +158,14 @@ class Learner1D(BaseLearner):
116158 to have values for both of the points in 'interval'.
117159 """
118160
119- def __init__ (self , function , bounds , loss_per_interval = None ):
161+ def __init__ (self , function , bounds , loss_per_interval = None , loss_depends_on_neighbours = False ):
120162 self .function = function
121- self .loss_per_interval = loss_per_interval or default_loss
163+ self ._loss_depends_on_neighbours = loss_depends_on_neighbours
164+
165+ if loss_depends_on_neighbours :
166+ self .loss_per_interval = loss_per_interval or get_curvature_loss ()
167+ else :
168+ self .loss_per_interval = loss_per_interval or default_loss
122169
123170 # A dict storing the loss function for each interval x_n.
124171 self .losses = {}
@@ -176,25 +223,42 @@ def loss(self, real=True):
176223 losses = self .losses if real else self .losses_combined
177224 return max (losses .values ()) if len (losses ) > 0 else float ('inf' )
178225
226+ def _get_loss_in_interval (self , x_left , x_right ):
227+ assert x_left is not None and x_right is not None
228+
229+ if x_right - x_left < self ._dx_eps :
230+ return 0
231+
232+ # we need to compute the loss for this interval
233+ interval = (x_left , x_right )
234+ if self ._loss_depends_on_neighbours :
235+ neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
236+ neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
237+ neighbours = neighbour_left , neighbour_right
238+ return self .loss_per_interval (interval , neighbours ,
239+ self ._scale , self .data )
240+ else :
241+ return self .loss_per_interval (interval , self ._scale , self .data )
242+
243+
179244 def _update_interpolated_loss_in_interval (self , x_left , x_right ):
180- if x_left is not None and x_right is not None :
181- dx = x_right - x_left
182- if dx < self ._dx_eps :
183- loss = 0
184- else :
185- loss = self .loss_per_interval ((x_left , x_right ),
186- self ._scale , self .data )
187- self .losses [x_left , x_right ] = loss
188-
189- # Iterate over all interpolated intervals in between
190- # x_left and x_right and set the newly interpolated loss.
191- a , b = x_left , None
192- while b != x_right :
193- b = self .neighbors_combined [a ][1 ]
194- self .losses_combined [a , b ] = (b - a ) * loss / dx
195- a = b
245+ if x_left is None or x_right is None :
246+ return
247+
248+ loss = self ._get_loss_in_interval (x_left , x_right )
249+ self .losses [x_left , x_right ] = loss
250+
251+ # Iterate over all interpolated intervals in between
252+ # x_left and x_right and set the newly interpolated loss.
253+ a , b = x_left , None
254+ dx = x_right - x_left
255+ while b != x_right :
256+ b = self .neighbors_combined [a ][1 ]
257+ self .losses_combined [a , b ] = (b - a ) * loss / dx
258+ a = b
196259
197260 def _update_losses (self , x , real = True ):
261+ """Update all losses that depend on x"""
198262 # When we add a new point x, we should update the losses
199263 # (x_left, x_right) are the "real" neighbors of 'x'.
200264 x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -212,6 +276,13 @@ def _update_losses(self, x, real=True):
212276 self ._update_interpolated_loss_in_interval (x_left , x )
213277 self ._update_interpolated_loss_in_interval (x , x_right )
214278
279+ # if the loss depends on the neighbors we should also update those losses
280+ if self ._loss_depends_on_neighbours :
281+ neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
282+ neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
283+ self ._update_interpolated_loss_in_interval (neighbour_left , x_left )
284+ self ._update_interpolated_loss_in_interval (x_right , neighbour_right )
285+
215286 # Since 'x' is in between (x_left, x_right),
216287 # we get rid of the interval.
217288 self .losses .pop ((x_left , x_right ), None )
@@ -358,7 +429,7 @@ def tell_many(self, xs, ys, *, force=False):
358429 self .losses = {}
359430 for x_left , x_right in intervals :
360431 self .losses [x_left , x_right ] = (
361- self .loss_per_interval (( x_left , x_right ), self . _scale , self . data )
432+ self ._get_loss_in_interval ( x_left , x_right )
362433 if x_right - x_left >= self ._dx_eps else 0 )
363434
364435 # List with "real" intervals that have interpolated intervals inside
0 commit comments