@@ -107,9 +107,10 @@ class LearnerND(BaseLearner):
107107 func: callable
108108 The function to learn. Must take a tuple of N real
109109 parameters and return a real number or an arraylike of length M.
110- bounds : list of 2-tuples
110+ bounds : list of 2-tuples or `scipy.spatial.ConvexHull`
111111 A list ``[(a_1, b_1), (a_2, b_2), ..., (a_n, b_n)]`` containing bounds,
112112 one pair per dimension.
113+ Or a ConvexHull that defines the boundary of the domain.
113114 loss_per_simplex : callable, optional
114115 A function that returns the loss for a simplex.
115116 If not provided, then a default is used, which uses
@@ -150,14 +151,21 @@ class LearnerND(BaseLearner):
150151 """
151152
152153 def __init__ (self , func , bounds , loss_per_simplex = None ):
153- self .ndim = len (bounds )
154154 self ._vdim = None
155155 self .loss_per_simplex = loss_per_simplex or default_loss
156- self .bounds = tuple (tuple (map (float , b )) for b in bounds )
157156 self .data = OrderedDict ()
158157 self .pending_points = set ()
159158
160- self ._bounds_points = list (map (tuple , itertools .product (* bounds )))
159+ if isinstance (bounds , scipy .spatial .ConvexHull ):
160+ hull_points = bounds .points [bounds .vertices ]
161+ self ._bounds_points = sorted (list (map (tuple , hull_points )))
162+ self ._bbox = tuple (zip (hull_points .min (axis = 0 ), hull_points .max (axis = 0 )))
163+ self ._interior = scipy .spatial .Delaunay (self ._bounds_points )
164+ else :
165+ self ._bounds_points = sorted (list (map (tuple , itertools .product (* bounds ))))
166+ self ._bbox = tuple (tuple (map (float , b )) for b in bounds )
167+
168+ self .ndim = len (self ._bbox )
161169
162170 self .function = func
163171 self ._tri = None
@@ -169,7 +177,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
169177 self ._subtriangulations = dict () # simplex → triangulation
170178
171179 # scale to unit
172- self ._transform = np .linalg .inv (np .diag (np .diff (bounds ).flat ))
180+ self ._transform = np .linalg .inv (np .diag (np .diff (self . _bbox ).flat ))
173181
174182 # create a private random number generator with fixed seed
175183 self ._random = random .Random (1 )
@@ -275,7 +283,12 @@ def _simplex_exists(self, simplex):
275283
276284 def inside_bounds (self , point ):
277285 """Check whether a point is inside the bounds."""
278- return all (mn <= p <= mx for p , (mn , mx ) in zip (point , self .bounds ))
286+ if hasattr (self , '_interior' ):
287+ return self ._interior .find_simplex (point , tol = 1e-8 ) >= 0
288+ else :
289+ eps = 1e-8
290+ return all ((mn - eps ) <= p <= (mx + eps ) for p , (mn , mx )
291+ in zip (point , self ._bbox ))
279292
280293 def tell_pending (self , point , * , simplex = None ):
281294 point = tuple (point )
@@ -349,11 +362,13 @@ def _ask_point_without_known_simplices(self):
349362 assert not self ._bounds_available
350363 # pick a random point inside the bounds
351364 # XXX: change this into picking a point based on volume loss
352- a = np .diff (self .bounds ).flat
353- b = np .array (self .bounds )[:, 0 ]
354- r = np .array ([self ._random .random () for _ in range (self .ndim )])
355- p = r * a + b
356- p = tuple (p )
365+ a = np .diff (self ._bbox ).flat
366+ b = np .array (self ._bbox )[:, 0 ]
367+ p = None
368+ while p is None or not self .inside_bounds (p ):
369+ r = np .array ([self ._random .random () for _ in range (self .ndim )])
370+ p = r * a + b
371+ p = tuple (p )
357372
358373 self .tell_pending (p )
359374 return p , np .inf
@@ -489,10 +504,10 @@ def plot(self, n=None, tri_alpha=0):
489504 if self .vdim > 1 :
490505 raise NotImplementedError ('holoviews currently does not support' ,
491506 '3D surface plots in bokeh.' )
492- if len (self .bounds ) != 2 :
507+ if len (self .ndim ) != 2 :
493508 raise NotImplementedError ("Only 2D plots are implemented: You can "
494509 "plot a 2D slice with 'plot_slice'." )
495- x , y = self .bounds
510+ x , y = self ._bbox
496511 lbrt = x [0 ], y [0 ], x [1 ], y [1 ]
497512
498513 if len (self .data ) >= 4 :
@@ -549,7 +564,7 @@ def plot_slice(self, cut_mapping, n=None):
549564 raise NotImplementedError ('multidimensional output not yet'
550565 ' supported by `plot_slice`' )
551566 n = n or 201
552- values = [cut_mapping .get (i , np .linspace (* self .bounds [i ], n ))
567+ values = [cut_mapping .get (i , np .linspace (* self ._bbox [i ], n ))
553568 for i in range (self .ndim )]
554569 ind = next (i for i in range (self .ndim ) if i not in cut_mapping )
555570 x = values [ind ]
@@ -574,9 +589,9 @@ def plot_slice(self, cut_mapping, n=None):
574589 xys = [xs [:, None ], ys [None , :]]
575590 values = [cut_mapping [i ] if i in cut_mapping
576591 else xys .pop (0 ) * (b [1 ] - b [0 ]) + b [0 ]
577- for i , b in enumerate (self .bounds )]
592+ for i , b in enumerate (self ._bbox )]
578593
579- lbrt = [b for i , b in enumerate (self .bounds )
594+ lbrt = [b for i , b in enumerate (self ._bbox )
580595 if i not in cut_mapping ]
581596 lbrt = np .reshape (lbrt , (2 , 2 )).T .flatten ().tolist ()
582597
0 commit comments