1+ from __future__ import annotations
2+
13import collections .abc
24import itertools
35import math
46from copy import copy , deepcopy
5- from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
7+ from typing import Any , Callable , Dict , List , Sequence , Tuple , Union
68
79import cloudpickle
810import numpy as np
@@ -170,7 +172,7 @@ def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
170172 return curvature_loss
171173
172174
173- def linspace (x_left : Real , x_right : Real , n : Int ) -> List [Float ]:
175+ def linspace (x_left : Real , x_right : Real , n : Int ) -> list [Float ]:
174176 """This is equivalent to
175177 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
176178 but it is 15-30 times faster for small 'n'."""
@@ -194,7 +196,7 @@ def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
194196
195197def _get_intervals (
196198 x : float , neighbors : NeighborsType , nth_neighbors : int
197- ) -> List [ Tuple [float , float ]]:
199+ ) -> list [ tuple [float , float ]]:
198200 nn = nth_neighbors
199201 i = neighbors .index (x )
200202 start = max (0 , i - nn - 1 )
@@ -249,9 +251,9 @@ class Learner1D(BaseLearner):
249251
250252 def __init__ (
251253 self ,
252- function : Callable [[Real ], Union [ Float , np .ndarray ] ],
253- bounds : Tuple [Real , Real ],
254- loss_per_interval : Optional [ Callable [[XsTypeN , YsTypeN ], Float ]] = None ,
254+ function : Callable [[Real ], Float | np .ndarray ],
255+ bounds : tuple [Real , Real ],
256+ loss_per_interval : Callable [[XsTypeN , YsTypeN ], Float ] | None = None ,
255257 ):
256258 self .function = function # type: ignore
257259
@@ -267,8 +269,8 @@ def __init__(
267269 # the learners behavior in the tests.
268270 self ._recompute_losses_factor = 2
269271
270- self .data : Dict [Real , Real ] = {}
271- self .pending_points : Set [Real ] = set ()
272+ self .data : dict [Real , Real ] = {}
273+ self .pending_points : set [Real ] = set ()
272274
273275 # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
274276 # properties.
@@ -292,7 +294,7 @@ def __init__(
292294 self .bounds = list (bounds )
293295 self .__missing_bounds = set (self .bounds ) # cache of missing bounds
294296
295- self ._vdim : Optional [ int ] = None
297+ self ._vdim : int | None = None
296298
297299 @property
298300 def vdim (self ) -> int :
@@ -334,20 +336,18 @@ def loss(self, real: bool = True) -> float:
334336 max_interval , max_loss = losses .peekitem (0 )
335337 return max_loss
336338
337- def _scale_x (self , x : Optional [ Float ] ) -> Optional [ Float ] :
339+ def _scale_x (self , x : Float | None ) -> Float | None :
338340 if x is None :
339341 return None
340342 return x / self ._scale [0 ]
341343
342- def _scale_y (
343- self , y : Union [Float , np .ndarray , None ]
344- ) -> Union [Float , np .ndarray , None ]:
344+ def _scale_y (self , y : Float | np .ndarray | None ) -> Float | np .ndarray | None :
345345 if y is None :
346346 return None
347347 y_scale = self ._scale [1 ] or 1
348348 return y / y_scale
349349
350- def _get_point_by_index (self , ind : int ) -> Optional [ float ] :
350+ def _get_point_by_index (self , ind : int ) -> float | None :
351351 if ind < 0 or ind >= len (self .neighbors ):
352352 return None
353353 return self .neighbors .keys ()[ind ]
@@ -449,7 +449,7 @@ def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
449449 neighbors .get (x_left , [None , None ])[1 ] = x
450450 neighbors .get (x_right , [None , None ])[0 ] = x
451451
452- def _update_scale (self , x : float , y : Union [ Float , np .ndarray ] ) -> None :
452+ def _update_scale (self , x : float , y : Float | np .ndarray ) -> None :
453453 """Update the scale with which the x and y-values are scaled.
454454
455455 For a learner where the function returns a single scalar the scale
@@ -476,7 +476,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
476476 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
477477 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
478478
479- def tell (self , x : float , y : Union [ Float , Sequence [Float ], np .ndarray ] ) -> None :
479+ def tell (self , x : float , y : Float | Sequence [Float ] | np .ndarray ) -> None :
480480 if x in self .data :
481481 # The point is already evaluated before
482482 return
@@ -522,13 +522,9 @@ def tell_pending(self, x: float) -> None:
522522 def tell_many (
523523 self ,
524524 xs : Sequence [Float ],
525- ys : Union [
526- Sequence [Float ],
527- Sequence [Sequence [Float ]],
528- Sequence [np .ndarray ],
529- ],
525+ ys : (Sequence [Float ] | Sequence [Sequence [Float ]] | Sequence [np .ndarray ]),
530526 * ,
531- force : bool = False
527+ force : bool = False ,
532528 ) -> None :
533529 if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
534530 # Only run this more efficient method if there are
@@ -597,7 +593,7 @@ def tell_many(
597593 # have an inf loss.
598594 self ._update_interpolated_loss_in_interval (* ival )
599595
600- def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ List [float ], List [float ]]:
596+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [ list [float ], list [float ]]:
601597 """Return 'n' points that are expected to maximally reduce the loss."""
602598 points , loss_improvements = self ._ask_points_without_adding (n )
603599
@@ -607,7 +603,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[floa
607603
608604 return points , loss_improvements
609605
610- def _missing_bounds (self ) -> List [Real ]:
606+ def _missing_bounds (self ) -> list [Real ]:
611607 missing_bounds = []
612608 for b in copy (self .__missing_bounds ):
613609 if b in self .data :
@@ -616,7 +612,7 @@ def _missing_bounds(self) -> List[Real]:
616612 missing_bounds .append (b )
617613 return sorted (missing_bounds )
618614
619- def _ask_points_without_adding (self , n : int ) -> Tuple [ List [float ], List [float ]]:
615+ def _ask_points_without_adding (self , n : int ) -> tuple [ list [float ], list [float ]]:
620616 """Return 'n' points that are expected to maximally reduce the loss.
621617 Without altering the state of the learner"""
622618 # Find out how to divide the n points over the intervals
@@ -691,8 +687,8 @@ def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
691687 return points , loss_improvements
692688
693689 def _loss (
694- self , mapping : Dict [Interval , float ], ival : Interval
695- ) -> Tuple [float , Interval ]:
690+ self , mapping : dict [Interval , float ], ival : Interval
691+ ) -> tuple [float , Interval ]:
696692 loss = mapping [ival ]
697693 return finite_loss (ival , loss , self ._scale [0 ])
698694
@@ -736,10 +732,10 @@ def remove_unfinished(self) -> None:
736732 self .losses_combined = deepcopy (self .losses )
737733 self .neighbors_combined = deepcopy (self .neighbors )
738734
739- def _get_data (self ) -> Dict [float , float ]:
735+ def _get_data (self ) -> dict [float , float ]:
740736 return self .data
741737
742- def _set_data (self , data : Dict [float , float ]) -> None :
738+ def _set_data (self , data : dict [float , float ]) -> None :
743739 if data :
744740 xs , ys = zip (* data .items ())
745741 self .tell_many (xs , ys )
@@ -763,7 +759,7 @@ def __setstate__(self, state):
763759 self .losses_combined .update (losses_combined )
764760
765761
766- def loss_manager (x_scale : float ) -> Dict [Interval , float ]:
762+ def loss_manager (x_scale : float ) -> dict [Interval , float ]:
767763 def sort_key (ival , loss ):
768764 loss , ival = finite_loss (ival , loss , x_scale )
769765 return - loss , ival
@@ -772,7 +768,7 @@ def sort_key(ival, loss):
772768 return sorted_dict
773769
774770
775- def finite_loss (ival : Interval , loss : float , x_scale : float ) -> Tuple [float , Interval ]:
771+ def finite_loss (ival : Interval , loss : float , x_scale : float ) -> tuple [float , Interval ]:
776772 """Get the so-called finite_loss of an interval in order to be able to
777773 sort intervals that have infinite loss."""
778774 # If the loss is infinite we return the
0 commit comments