1515from ..utils import cache_latest
1616
1717
18- def uniform_loss (interval , scale , function_values ):
18+ def uniform_loss (interval , scale , function_values , neighbors ):
1919 """Loss function that samples the domain uniformly.
2020
2121 Works with `~adaptive.Learner1D` only.
@@ -36,7 +36,7 @@ def uniform_loss(interval, scale, function_values):
3636 return dx
3737
3838
39- def default_loss (interval , scale , function_values ):
39+ def default_loss (interval , scale , function_values , neighbors ):
4040 """Calculate loss on a single interval.
4141
4242 Currently returns the rescaled length of the interval. If one of the
@@ -70,12 +70,9 @@ def _loss_of_multi_interval(xs, ys):
7070 return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
7171
7272
73- def triangle_loss (interval , neighbours , scale , function_values ):
73+ def triangle_loss (interval , scale , function_values , neighbors ):
7474 x_left , x_right = interval
75- neighbour_left , neighbour_right = neighbours
76- xs = [neighbour_left , x_left , x_right , neighbour_right ]
77- # The neighbours could be None if we are at the boundary, in that case we
78- # have to filter this out
75+ xs = [neighbors [x_left ][0 ], x_left , x_right , neighbors [x_right ][1 ]]
7976 xs = [x for x in xs if x is not None ]
8077
8178 if len (xs ) <= 2 :
@@ -88,9 +85,9 @@ def triangle_loss(interval, neighbours, scale, function_values):
8885
8986
9087def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
91- def curvature_loss (interval , neighbours , scale , function_values ):
92- triangle_loss_ = triangle_loss (interval , neighbours , scale , function_values )
93- default_loss_ = default_loss (interval , scale , function_values )
88+ def curvature_loss (interval , scale , function_values , neighbors ):
89+ triangle_loss_ = triangle_loss (interval , scale , function_values , neighbors )
90+ default_loss_ = default_loss (interval , scale , function_values , neighbors )
9491 dx = (interval [1 ] - interval [0 ]) / scale [0 ]
9592 return (area_factor * (triangle_loss_ ** 0.5 )
9693 + euclid_factor * default_loss_
@@ -121,6 +118,15 @@ def _get_neighbors_from_list(xs):
121118 return sortedcontainers .SortedDict (neighbors )
122119
123120
121+ def _get_intervals (x , neighbors , nn_neighbors ):
122+ nn = nn_neighbors
123+ i = neighbors .index (x )
124+ start = max (0 , i - nn - 1 )
125+ end = min (len (neighbors ), i + nn + 2 )
126+ points = neighbors .keys ()[start :end ]
127+ return list (zip (points , points [1 :]))
128+
129+
124130class Learner1D (BaseLearner ):
125131 """Learns and predicts a function 'f:ℝ → ℝ^N'.
126132
@@ -135,6 +141,10 @@ class Learner1D(BaseLearner):
135141 A function that returns the loss for a single interval of the domain.
136142 If not provided, then a default is used, which uses the scaled distance
137143 in the x-y plane as the loss. See the notes for more details.
144+ nn_neighbors : int, default: 0
145+ The number of neighboring intervals that the loss function
146+ takes into account. If ``loss_per_interval`` doesn't use the neighbors
147+ at all, then it should be 0.
138148
139149 Attributes
140150 ----------
@@ -145,9 +155,9 @@ class Learner1D(BaseLearner):
145155
146156 Notes
147157 -----
148- `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
149- ``function_values ``, and returns a scalar; the loss over the interval.
150-
158+ `loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
159+ ``data ``, and ``neighbors``, and returns a scalar; the loss over
160+ the interval.
151161 interval : (float, float)
152162 The bounds of the interval.
153163 scale : (float, float)
@@ -156,16 +166,18 @@ class Learner1D(BaseLearner):
156166 function_values : dict(float → float)
157167 A map containing evaluated function values. It is guaranteed
158168 to have values for both of the points in 'interval'.
169+ neighbors : dict(float → (float, float))
170+ A map containing points as keys to its neighbors as a tuple.
159171 """
160172
161- def __init__ (self , function , bounds , loss_per_interval = None , loss_depends_on_neighbours = False ):
173+ def __init__ (self , function , bounds , loss_per_interval = None , nn_neighbors = 0 ):
162174 self .function = function
163- self ._loss_depends_on_neighbours = loss_depends_on_neighbours
175+ self .nn_neighbors = nn_neighbors
164176
165- if loss_depends_on_neighbours :
166- self .loss_per_interval = loss_per_interval or get_curvature_loss ()
167- else :
177+ if nn_neighbors == 0 :
168178 self .loss_per_interval = loss_per_interval or default_loss
179+ else :
180+ self .loss_per_interval = loss_per_interval or get_curvature_loss ()
169181
170182 # A dict storing the loss function for each interval x_n.
171183 self .losses = {}
@@ -230,15 +242,8 @@ def _get_loss_in_interval(self, x_left, x_right):
230242 return 0
231243
232244 # we need to compute the loss for this interval
233- interval = (x_left , x_right )
234- if self ._loss_depends_on_neighbours :
235- neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
236- neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
237- neighbours = neighbour_left , neighbour_right
238- return self .loss_per_interval (interval , neighbours ,
239- self ._scale , self .data )
240- else :
241- return self .loss_per_interval (interval , self ._scale , self .data )
245+ return self .loss_per_interval (
246+ (x_left , x_right ), self ._scale , self .data , self .neighbors )
242247
243248
244249 def _update_interpolated_loss_in_interval (self , x_left , x_right ):
@@ -271,17 +276,11 @@ def _update_losses(self, x, real=True):
271276
272277 if real :
273278 # We need to update all interpolated losses in the interval
274- # (x_left, x) and (x, x_right). Since the addition of the point
275- # 'x' could change their loss.
276- self ._update_interpolated_loss_in_interval (x_left , x )
277- self ._update_interpolated_loss_in_interval (x , x_right )
278-
279- # if the loss depends on the neighbors we should also update those losses
280- if self ._loss_depends_on_neighbours :
281- neighbour_left = self .neighbors .get (x_left , (None , None ))[0 ]
282- neighbour_right = self .neighbors .get (x_right , (None , None ))[1 ]
283- self ._update_interpolated_loss_in_interval (neighbour_left , x_left )
284- self ._update_interpolated_loss_in_interval (x_right , neighbour_right )
279+ # (x_left, x), (x, x_right) and the nn_neighbors nearest
280+ # neighboring intervals. Since the addition of the
281+ # point 'x' could change their loss.
282+ for ival in _get_intervals (x , self .neighbors , self .nn_neighbors ):
283+ self ._update_interpolated_loss_in_interval (* ival )
285284
286285 # Since 'x' is in between (x_left, x_right),
287286 # we get rid of the interval.
@@ -427,10 +426,8 @@ def tell_many(self, xs, ys, *, force=False):
427426
428427 # The the losses for the "real" intervals.
429428 self .losses = {}
430- for x_left , x_right in intervals :
431- self .losses [x_left , x_right ] = (
432- self ._get_loss_in_interval (x_left , x_right )
433- if x_right - x_left >= self ._dx_eps else 0 )
429+ for ival in intervals :
430+ self .losses [ival ] = self ._get_loss_in_interval (* ival )
434431
435432 # List with "real" intervals that have interpolated intervals inside
436433 to_interpolate = []
0 commit comments