22import math
33from collections .abc import Iterable
44from copy import deepcopy
5+ from functools import partial
6+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
57
68import numpy as np
79import sortedcollections
810import sortedcontainers
11+ from numpy import float64 , ndarray
12+ from sortedcollections .recipes import ItemSortedDict
13+ from sortedcontainers .sorteddict import SortedDict
914
1015from adaptive .learner .base_learner import BaseLearner , uses_nth_neighbors
1116from adaptive .learner .learnerND import volume
1520
1621
1722@uses_nth_neighbors (0 )
18- def uniform_loss (xs , ys ):
23+ def uniform_loss (
24+ xs : Union [Tuple [float , float ], Tuple [float64 , float64 ]],
25+ ys : Union [Tuple [float , float ], Tuple [float64 , float64 ]],
26+ ) -> Union [float64 , float ]:
1927 """Loss function that samples the domain uniformly.
2028
2129 Works with `~adaptive.Learner1D` only.
@@ -35,7 +43,15 @@ def uniform_loss(xs, ys):
3543
3644
3745@uses_nth_neighbors (0 )
38- def default_loss (xs , ys ):
46+ def default_loss (
47+ xs : Union [
48+ Tuple [float , float ],
49+ Tuple [float64 , float ],
50+ Tuple [float64 , float64 ],
51+ Tuple [float , float64 ],
52+ ],
53+ ys : Union [Tuple [float , float ], Tuple [ndarray , ndarray ], Tuple [float64 , float64 ]],
54+ ) -> float64 :
3955 """Calculate loss on a single interval.
4056
4157 Currently returns the rescaled length of the interval. If one of the
@@ -52,7 +68,7 @@ def default_loss(xs, ys):
5268
5369
5470@uses_nth_neighbors (1 )
55- def triangle_loss (xs , ys ) :
71+ def triangle_loss (xs : Any , ys : Any ) -> Union [ float64 , float ] :
5672 xs = [x for x in xs if x is not None ]
5773 ys = [y for y in ys if y is not None ]
5874
@@ -69,7 +85,9 @@ def triangle_loss(xs, ys):
6985 return sum (vol (pts [i : i + 3 ]) for i in range (N )) / N
7086
7187
72- def curvature_loss_function (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
88+ def curvature_loss_function (
89+ area_factor : int = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
90+ ) -> Callable :
7391 # XXX: add a doc-string
7492 @uses_nth_neighbors (1 )
7593 def curvature_loss (xs , ys ):
@@ -88,7 +106,9 @@ def curvature_loss(xs, ys):
88106 return curvature_loss
89107
90108
91- def linspace (x_left , x_right , n ):
109+ def linspace (
110+ x_left : Union [int , float64 , float ], x_right : Union [int , float64 , float ], n : int
111+ ) -> Union [List [float ], List [float64 ]]:
92112 """This is equivalent to
93113 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
94114 but it is 15-30 times faster for small 'n'."""
@@ -100,7 +120,7 @@ def linspace(x_left, x_right, n):
100120 return [x_left + step * i for i in range (1 , n )]
101121
102122
103- def _get_neighbors_from_list (xs ) :
123+ def _get_neighbors_from_list (xs : ndarray ) -> SortedDict :
104124 xs = np .sort (xs )
105125 xs_left = np .roll (xs , 1 ).tolist ()
106126 xs_right = np .roll (xs , - 1 ).tolist ()
@@ -110,7 +130,9 @@ def _get_neighbors_from_list(xs):
110130 return sortedcontainers .SortedDict (neighbors )
111131
112132
113- def _get_intervals (x , neighbors , nth_neighbors ):
133+ def _get_intervals (
134+ x : Union [int , float64 , float ], neighbors : SortedDict , nth_neighbors : int
135+ ) -> Any :
114136 nn = nth_neighbors
115137 i = neighbors .index (x )
116138 start = max (0 , i - nn - 1 )
@@ -163,7 +185,12 @@ class Learner1D(BaseLearner):
163185 decorator for more information.
164186 """
165187
166- def __init__ (self , function , bounds , loss_per_interval = None ):
188+ def __init__ (
189+ self ,
190+ function : Union [Callable , partial ],
191+ bounds : Union [Tuple [int , int ], Tuple [float , float ], ndarray ],
192+ loss_per_interval : Optional [Callable ] = None ,
193+ ) -> None :
167194 self .function = function
168195
169196 if hasattr (loss_per_interval , "nth_neighbors" ):
@@ -205,7 +232,7 @@ def __init__(self, function, bounds, loss_per_interval=None):
205232 self ._vdim = None
206233
207234 @property
208- def vdim (self ):
235+ def vdim (self ) -> int :
209236 """Length of the output of ``learner.function``.
210237 If the output is unsized (when it's a scalar)
211238 then `vdim = 1`.
@@ -225,35 +252,41 @@ def vdim(self):
225252 return self ._vdim
226253
227254 @property
228- def npoints (self ):
255+ def npoints (self ) -> int :
229256 """Number of evaluated points."""
230257 return len (self .data )
231258
232259 @cache_latest
233- def loss (self , real = True ):
260+ def loss (self , real : bool = True ) -> Union [ int , float64 , float ] :
234261 losses = self .losses if real else self .losses_combined
235262 if not losses :
236263 return np .inf
237264 max_interval , max_loss = losses .peekitem (0 )
238265 return max_loss
239266
240- def _scale_x (self , x ):
267+ def _scale_x (
268+ self , x : Optional [Union [float , int , float64 ]]
269+ ) -> Optional [Union [float , float64 ]]:
241270 if x is None :
242271 return None
243272 return x / self ._scale [0 ]
244273
245- def _scale_y (self , y ):
274+ def _scale_y (
275+ self , y : Optional [Union [int , ndarray , float64 , float ]]
276+ ) -> Optional [Union [float , float64 , ndarray ]]:
246277 if y is None :
247278 return None
248279 y_scale = self ._scale [1 ] or 1
249280 return y / y_scale
250281
251- def _get_point_by_index (self , ind ) :
282+ def _get_point_by_index (self , ind : int ) -> Optional [ Union [ int , float64 , float ]] :
252283 if ind < 0 or ind >= len (self .neighbors ):
253284 return None
254285 return self .neighbors .keys ()[ind ]
255286
256- def _get_loss_in_interval (self , x_left , x_right ):
287+ def _get_loss_in_interval (
288+ self , x_left : Union [int , float64 , float ], x_right : Union [int , float64 , float ]
289+ ) -> Union [int , float64 , float ]:
257290 assert x_left is not None and x_right is not None
258291
259292 if x_right - x_left < self ._dx_eps :
@@ -273,7 +306,9 @@ def _get_loss_in_interval(self, x_left, x_right):
273306 # we need to compute the loss for this interval
274307 return self .loss_per_interval (xs_scaled , ys_scaled )
275308
276- def _update_interpolated_loss_in_interval (self , x_left , x_right ):
309+ def _update_interpolated_loss_in_interval (
310+ self , x_left : Union [int , float64 , float ], x_right : Union [int , float64 , float ]
311+ ) -> None :
277312 if x_left is None or x_right is None :
278313 return
279314
@@ -289,7 +324,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
289324 self .losses_combined [a , b ] = (b - a ) * loss / dx
290325 a = b
291326
292- def _update_losses (self , x , real = True ):
327+ def _update_losses (self , x : Union [ int , float64 , float ], real : bool = True ) -> None :
293328 """Update all losses that depend on x"""
294329 # When we add a new point x, we should update the losses
295330 # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -332,7 +367,7 @@ def _update_losses(self, x, real=True):
332367 self .losses_combined [x , b ] = float ("inf" )
333368
334369 @staticmethod
335- def _find_neighbors (x , neighbors ) :
370+ def _find_neighbors (x : Union [ int , float64 , float ], neighbors : SortedDict ) -> Any :
336371 if x in neighbors :
337372 return neighbors [x ]
338373 pos = neighbors .bisect_left (x )
@@ -341,14 +376,18 @@ def _find_neighbors(x, neighbors):
341376 x_right = keys [pos ] if pos != len (neighbors ) else None
342377 return x_left , x_right
343378
344- def _update_neighbors (self , x , neighbors ):
379+ def _update_neighbors (
380+ self , x : Union [int , float64 , float ], neighbors : SortedDict
381+ ) -> None :
345382 if x not in neighbors : # The point is new
346383 x_left , x_right = self ._find_neighbors (x , neighbors )
347384 neighbors [x ] = [x_left , x_right ]
348385 neighbors .get (x_left , [None , None ])[1 ] = x
349386 neighbors .get (x_right , [None , None ])[0 ] = x
350387
351- def _update_scale (self , x , y ):
388+ def _update_scale (
389+ self , x : Union [int , float64 , float ], y : Union [float , int , float64 , ndarray ]
390+ ) -> None :
352391 """Update the scale with which the x and y-values are scaled.
353392
354393 For a learner where the function returns a single scalar the scale
@@ -375,7 +414,7 @@ def _update_scale(self, x, y):
375414 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
376415 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
377416
378- def tell (self , x , y ) :
417+ def tell (self , x : Union [ int , float64 , float ], y : Any ) -> None :
379418 if x in self .data :
380419 # The point is already evaluated before
381420 return
@@ -410,15 +449,15 @@ def tell(self, x, y):
410449
411450 self ._oldscale = deepcopy (self ._scale )
412451
413- def tell_pending (self , x ) :
452+ def tell_pending (self , x : Union [ int , float64 , float ]) -> None :
414453 if x in self .data :
415454 # The point is already evaluated before
416455 return
417456 self .pending_points .add (x )
418457 self ._update_neighbors (x , self .neighbors_combined )
419458 self ._update_losses (x , real = False )
420459
421- def tell_many (self , xs , ys , * , force = False ):
460+ def tell_many (self , xs : Any , ys : Any , * , force = False ) -> None :
422461 if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
423462 # Only run this more efficient method if there are
424463 # at least 2 points and the amount of points added are
@@ -486,7 +525,7 @@ def tell_many(self, xs, ys, *, force=False):
486525 # have an inf loss.
487526 self ._update_interpolated_loss_in_interval (* ival )
488527
489- def ask (self , n , tell_pending = True ):
528+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
490529 """Return 'n' points that are expected to maximally reduce the loss."""
491530 points , loss_improvements = self ._ask_points_without_adding (n )
492531
@@ -496,7 +535,7 @@ def ask(self, n, tell_pending=True):
496535
497536 return points , loss_improvements
498537
499- def _ask_points_without_adding (self , n ) :
538+ def _ask_points_without_adding (self , n : int ) -> Any :
500539 """Return 'n' points that are expected to maximally reduce the loss.
501540 Without altering the state of the learner"""
502541 # Find out how to divide the n points over the intervals
@@ -574,7 +613,7 @@ def _ask_points_without_adding(self, n):
574613
575614 return points , loss_improvements
576615
577- def _loss (self , mapping , ival ) :
616+ def _loss (self , mapping : ItemSortedDict , ival : Any ) -> Any :
578617 loss = mapping [ival ]
579618 return finite_loss (ival , loss , self ._scale [0 ])
580619
@@ -613,20 +652,20 @@ def plot(self, *, scatter_or_line="scatter"):
613652
614653 return p .redim (x = dict (range = plot_bounds ))
615654
616- def remove_unfinished (self ):
655+ def remove_unfinished (self ) -> None :
617656 self .pending_points = set ()
618657 self .losses_combined = deepcopy (self .losses )
619658 self .neighbors_combined = deepcopy (self .neighbors )
620659
621- def _get_data (self ):
660+ def _get_data (self ) -> Dict [ Union [ int , float ], float ] :
622661 return self .data
623662
624- def _set_data (self , data ) :
663+ def _set_data (self , data : Dict [ Union [ int , float ], float ]) -> None :
625664 if data :
626665 self .tell_many (* zip (* data .items ()))
627666
628667
629- def loss_manager (x_scale ) :
668+ def loss_manager (x_scale : Union [ int , float64 , float ]) -> ItemSortedDict :
630669 def sort_key (ival , loss ):
631670 loss , ival = finite_loss (ival , loss , x_scale )
632671 return - loss , ival
@@ -635,7 +674,9 @@ def sort_key(ival, loss):
635674 return sorted_dict
636675
637676
638- def finite_loss (ival , loss , x_scale ):
677+ def finite_loss (
678+ ival : Any , loss : Union [int , float64 , float ], x_scale : Union [int , float64 , float ]
679+ ) -> Any :
639680 """Get the socalled finite_loss of an interval in order to be able to
640681 sort intervals that have infinite loss."""
641682 # If the loss is infinite we return the
0 commit comments