2626from adaptive .notebook_integration import ensure_holoviews
2727from adaptive .utils import cache_latest
2828
29+ Point = Tuple [float , float ]
30+
2931
3032@uses_nth_neighbors (0 )
31- def uniform_loss (xs : Tuple [ float , float ], ys : Tuple [ float , float ] ) -> float :
33+ def uniform_loss (xs : Point , ys : Any ) -> float :
3234 """Loss function that samples the domain uniformly.
3335
3436 Works with `~adaptive.Learner1D` only.
@@ -49,8 +51,8 @@ def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float]) -> float:
4951
5052@uses_nth_neighbors (0 )
5153def default_loss (
52- xs : Tuple [ float , float ] ,
53- ys : Union [Tuple [Iterable [float ], Iterable [float ]], Tuple [ float , float ] ],
54+ xs : Point ,
55+ ys : Union [Tuple [Iterable [float ], Iterable [float ]], Point ],
5456) -> float :
5557 """Calculate loss on a single interval.
5658
@@ -60,8 +62,8 @@ def default_loss(
6062 """
6163 dx = xs [1 ] - xs [0 ]
6264 if isinstance (ys [0 ], collections .abc .Iterable ):
63- dy = [abs (a - b ) for a , b in zip (* ys )]
64- return np .hypot (dx , dy ).max ()
65+ dy_vec = [abs (a - b ) for a , b in zip (* ys )]
66+ return np .hypot (dx , dy_vec ).max ()
6567 else :
6668 dy = ys [1 ] - ys [0 ]
6769 return np .hypot (dx , dy )
@@ -200,7 +202,7 @@ def __init__(
200202 bounds : Tuple [float , float ],
201203 loss_per_interval : Optional [Callable ] = None ,
202204 ) -> None :
203- self .function = function
205+ self .function = function # type: ignore
204206
205207 if hasattr (loss_per_interval , "nth_neighbors" ):
206208 self .nth_neighbors = loss_per_interval .nth_neighbors
@@ -238,7 +240,7 @@ def __init__(
238240
239241 self .bounds = list (bounds )
240242
241- self ._vdim = None
243+ self ._vdim : Optional [ int ] = None
242244
243245 @property
244246 def vdim (self ) -> int :
@@ -565,7 +567,8 @@ def _ask_points_without_adding(self, n: int) -> Any:
565567 # Add bound intervals to quals if bounds were missing.
566568 if len (self .data ) + len (self .pending_points ) == 0 :
567569 # We don't have any points, so return a linspace with 'n' points.
568- return np .linspace (* self .bounds , n ).tolist (), [np .inf ] * n
570+ a , b = self .bounds
571+ return np .linspace (a , b , n ).tolist (), [np .inf ] * n
569572
570573 quals = loss_manager (self ._scale [0 ])
571574 if len (missing_bounds ) > 0 :
@@ -601,7 +604,7 @@ def _ask_points_without_adding(self, n: int) -> Any:
601604 quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
602605
603606 points = list (
604- itertools .chain .from_iterable (linspace (* ival , n ) for (* ival , n ) in quals )
607+ itertools .chain .from_iterable (linspace (a , b , n ) for (( a , b ) , n ) in quals )
605608 )
606609
607610 loss_improvements = list (
@@ -665,7 +668,8 @@ def _get_data(self) -> Dict[float, float]:
665668
666669 def _set_data (self , data : Dict [float , float ]) -> None :
667670 if data :
668- self .tell_many (* zip (* data .items ()))
671+ xs , ys = zip (* data .items ())
672+ self .tell_many (xs , ys )
669673
670674 def __getstate__ (self ):
671675 return (
0 commit comments