1212import scipy .spatial
1313from sortedcontainers import SortedKeyList
1414
15- from adaptive .learner .base_learner import BaseLearner
15+ from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
1616from adaptive .notebook_integration import ensure_holoviews , ensure_plotly
1717from adaptive .learner .triangulation import (
1818 Triangulation , point_in_simplex , circumsphere ,
1919 simplex_volume_in_embedding , fast_det )
2020from adaptive .utils import restore , cache_latest
2121
2222
23+ def to_list (inp ):
24+ if isinstance (inp , Iterable ):
25+ return list (inp )
26+ return [inp ]
27+
28+
2329def volume (simplex , ys = None ):
2430 # Notice the parameter ys is there so you can use this volume method as
2531 # as loss function
@@ -60,6 +66,71 @@ def default_loss(simplex, ys):
6066 return simplex_volume_in_embedding (pts )
6167
6268
69+ @uses_nth_neighbors (1 )
70+ def triangle_loss (simplex , values , neighbors , neighbor_values ):
71+ """
72+ Computes the average of the volumes of the simplex combined with each
73+ neighbouring point.
74+
75+ Parameters
76+ ----------
77+ simplex : list of tuples
78+ Each entry is one point of the simplex.
79+ values : list of values
80+ The function values of each of the simplex points.
81+ neighbors : list of tuples
82+ The neighboring points of the simplex, ordered such that simplex[0]
83+ exacly opposes neighbors[0], etc.
84+ neighbor_values : list of values
85+ The function values for each of the neighboring points.
86+
87+ Returns
88+ -------
89+ loss : float
90+ """
91+
92+ neighbors = [n for n in neighbors if n is not None ]
93+ neighbor_values = [v for v in neighbor_values if v is not None ]
94+ if len (neighbors ) == 0 :
95+ return 0
96+
97+ s = [(* x , * to_list (y )) for x , y in zip (simplex , values )]
98+ n = [(* x , * to_list (y )) for x , y in zip (neighbors , neighbor_values )]
99+
100+ return sum (simplex_volume_in_embedding ([* s , neighbor ])
101+ for neighbor in n ) / len (neighbors )
102+
103+
104+ def curvature_loss_function (exploration = 0.05 ):
105+ # XXX: add doc-string!
106+ @uses_nth_neighbors (1 )
107+ def curvature_loss (simplex , values , neighbors , neighbor_values ):
108+ """Compute the curvature loss of a simplex.
109+
110+ Parameters
111+ ----------
112+ simplex : list of tuples
113+ Each entry is one point of the simplex.
114+ values : list of values
115+ The function values of each of the simplex points.
116+ neighbors : list of tuples
117+ The neighboring points of the simplex, ordered such that simplex[0]
118+ exacly opposes neighbors[0], etc.
119+ neighbor_values : list of values
120+ The function values for each of the neighboring points.
121+
122+ Returns
123+ -------
124+ loss : float
125+ """
126+ dim = len (simplex [0 ]) # the number of coordinates
127+ loss_input_volume = volume (simplex )
128+
129+ loss_curvature = triangle_loss (simplex , values , neighbors , neighbor_values )
130+ return (loss_curvature + exploration * loss_input_volume ** ((2 + dim ) / dim )) ** (1 / (2 + dim ))
131+ return curvature_loss
132+
133+
63134def choose_point_in_simplex (simplex , transform = None ):
64135 """Choose a new point in inside a simplex.
65136
@@ -70,9 +141,10 @@ def choose_point_in_simplex(simplex, transform=None):
70141 Parameters
71142 ----------
72143 simplex : numpy array
73- The coordinates of a triangle with shape (N+1, N)
144+ The coordinates of a triangle with shape (N+1, N).
74145 transform : N*N matrix
75- The multiplication to apply to the simplex before choosing the new point
146+ The multiplication to apply to the simplex before choosing
147+ the new point.
76148
77149 Returns
78150 -------
@@ -164,6 +236,17 @@ class LearnerND(BaseLearner):
164236 def __init__ (self , func , bounds , loss_per_simplex = None ):
165237 self ._vdim = None
166238 self .loss_per_simplex = loss_per_simplex or default_loss
239+
240+ if hasattr (self .loss_per_simplex , 'nth_neighbors' ):
241+ if self .loss_per_simplex .nth_neighbors > 1 :
242+ raise NotImplementedError ('The provided loss function wants '
243+ 'next-nearest neighboring simplices for the loss computation, '
244+ 'this feature is not yet implemented, either use '
245+ 'nth_neightbors = 0 or 1' )
246+ self .nth_neighbors = self .loss_per_simplex .nth_neighbors
247+ else :
248+ self .nth_neighbors = 0
249+
167250 self .data = OrderedDict ()
168251 self .pending_points = set ()
169252
@@ -252,14 +335,15 @@ def tri(self):
252335
253336 try :
254337 self ._tri = Triangulation (self .points )
255- self ._update_losses (set (), self ._tri .simplices )
256- return self ._tri
257338 except ValueError :
258339 # A ValueError is raised if we do not have enough points or
259340 # the provided points are coplanar, so we need more points to
260341 # create a valid triangulation
261342 return None
262343
344+ self ._update_losses (set (), self ._tri .simplices )
345+ return self ._tri
346+
263347 @property
264348 def values (self ):
265349 """Get the values from `data` as a numpy array."""
@@ -326,10 +410,10 @@ def tell_pending(self, point, *, simplex=None):
326410
327411 simplex = tuple (simplex )
328412 simplices = [self .tri .vertex_to_simplices [i ] for i in simplex ]
329- neighbours = set .union (* simplices )
413+ neighbors = set .union (* simplices )
330414 # Neighbours also includes the simplex itself
331415
332- for simpl in neighbours :
416+ for simpl in neighbors :
333417 _ , to_add = self ._try_adding_pending_point_to_simplex (point , simpl )
334418 if to_add is None :
335419 continue
@@ -394,6 +478,7 @@ def _pop_highest_existing_simplex(self):
394478 # find the simplex with the highest loss, we do need to check that the
395479 # simplex hasn't been deleted yet
396480 while len (self ._simplex_queue ):
481+ # XXX: Need to add check that the loss is the most recent computed loss
397482 loss , simplex , subsimplex = self ._simplex_queue .pop (0 )
398483 if (subsimplex is None
399484 and simplex in self .tri .simplices
@@ -449,6 +534,35 @@ def _ask(self):
449534
450535 return self ._ask_best_point () # O(log N)
451536
537+ def _compute_loss (self , simplex ):
538+ # get the loss
539+ vertices = self .tri .get_vertices (simplex )
540+ values = [self .data [tuple (v )] for v in vertices ]
541+
542+ # scale them to a cube with sides 1
543+ vertices = vertices @ self ._transform
544+ values = self ._output_multiplier * np .array (values )
545+
546+ if self .nth_neighbors == 0 :
547+ # compute the loss on the scaled simplex
548+ return float (self .loss_per_simplex (vertices , values ))
549+
550+ # We do need the neighbors
551+ neighbors = self .tri .get_opposing_vertices (simplex )
552+
553+ neighbor_points = self .tri .get_vertices (neighbors )
554+ neighbor_values = [self .data .get (x , None ) for x in neighbor_points ]
555+
556+ for i , point in enumerate (neighbor_points ):
557+ if point is not None :
558+ neighbor_points [i ] = point @ self ._transform
559+
560+ for i , value in enumerate (neighbor_values ):
561+ if value is not None :
562+ neighbor_values [i ] = self ._output_multiplier * value
563+
564+ return float (self .loss_per_simplex (vertices , values , neighbor_points , neighbor_values ))
565+
452566 def _update_losses (self , to_delete : set , to_add : set ):
453567 # XXX: add the points outside the triangulation to this as well
454568 pending_points_unbound = set ()
@@ -461,7 +575,6 @@ def _update_losses(self, to_delete: set, to_add: set):
461575
462576 pending_points_unbound = set (p for p in pending_points_unbound
463577 if p not in self .data )
464-
465578 for simplex in to_add :
466579 loss = self ._compute_loss (simplex )
467580 self ._losses [simplex ] = loss
@@ -476,17 +589,20 @@ def _update_losses(self, to_delete: set, to_add: set):
476589 self ._update_subsimplex_losses (
477590 simplex , self ._subtriangulations [simplex ].simplices )
478591
479- def _compute_loss (self , simplex ):
480- # get the loss
481- vertices = self .tri .get_vertices (simplex )
482- values = [self .data [tuple (v )] for v in vertices ]
592+ if self .nth_neighbors :
593+ points_of_added_simplices = set .union (* [set (s ) for s in to_add ])
594+ neighbors = self .tri .get_simplices_attached_to_points (
595+ points_of_added_simplices ) - to_add
596+ for simplex in neighbors :
597+ loss = self ._compute_loss (simplex )
598+ self ._losses [simplex ] = loss
483599
484- # scale them to a cube with sides 1
485- vertices = vertices @ self ._transform
486- values = self . _output_multiplier * np . array ( values )
600+ if simplex not in self . _subtriangulations :
601+ self ._simplex_queue . add (( loss , simplex , None ))
602+ continue
487603
488- # compute the loss on the scaled simplex
489- return float ( self . loss_per_simplex ( vertices , values ) )
604+ self . _update_subsimplex_losses (
605+ simplex , self . _subtriangulations [ simplex ]. simplices )
490606
491607 def _recompute_all_losses (self ):
492608 """Recompute all losses and pending losses."""
0 commit comments