11from copy import copy
2+ from functools import partial
3+ from typing import Any , List , Tuple , Union
24
5+ from numpy import float64 , ndarray
36from sortedcontainers import SortedDict , SortedSet
47
58from adaptive .learner .base_learner import BaseLearner
@@ -15,17 +18,22 @@ class _IgnoreFirstArgument:
1518 pickable.
1619 """
1720
18- def __init__ (self , function ) :
21+ def __init__ (self , function : partial ) -> None :
1922 self .function = function
2023
21- def __call__ (self , index_point , * args , ** kwargs ):
24+ def __call__ (
25+ self ,
26+ index_point : Union [Tuple [int , int ], Tuple [int , float64 ], Tuple [int , ndarray ]],
27+ * args ,
28+ ** kwargs
29+ ) -> Union [float64 , float ]:
2230 index , point = index_point
2331 return self .function (point , * args , ** kwargs )
2432
25- def __getstate__ (self ):
33+ def __getstate__ (self ) -> partial :
2634 return self .function
2735
28- def __setstate__ (self , function ) :
36+ def __setstate__ (self , function : partial ) -> None :
2937 self .__init__ (function )
3038
3139
@@ -56,7 +64,7 @@ class SequenceLearner(BaseLearner):
5664 the added benefit of having results in the local kernel already.
5765 """
5866
59- def __init__ (self , function , sequence ) :
67+ def __init__ (self , function : partial , sequence : Union [ range , ndarray ]) -> None :
6068 self ._original_function = function
6169 self .function = _IgnoreFirstArgument (function )
6270 self ._to_do_indices = SortedSet ({i for i , _ in enumerate (sequence )})
@@ -65,7 +73,13 @@ def __init__(self, function, sequence):
6573 self .data = SortedDict ()
6674 self .pending_points = set ()
6775
68- def ask (self , n , tell_pending = True ):
76+ def ask (
77+ self , n : int , tell_pending : bool = True
78+ ) -> Union [
79+ Tuple [List [Tuple [int , float64 ]], List [float ]],
80+ Tuple [List [Tuple [int , int ]], List [float ]],
81+ Tuple [List [Tuple [int , ndarray ]], List [float ]],
82+ ]:
6983 indices = []
7084 points = []
7185 loss_improvements = []
@@ -83,17 +97,17 @@ def ask(self, n, tell_pending=True):
8397
8498 return points , loss_improvements
8599
86- def _get_data (self ):
100+ def _get_data (self ) -> SortedDict :
87101 return self .data
88102
89- def _set_data (self , data ) :
103+ def _set_data (self , data : SortedDict ) -> None :
90104 if data :
91105 indices , values = zip (* data .items ())
92106 # the points aren't used by tell, so we can safely pass None
93107 points = [(i , None ) for i in indices ]
94108 self .tell_many (points , values )
95109
96- def loss (self , real = True ):
110+ def loss (self , real : bool = True ) -> float :
97111 if not (self ._to_do_indices or self .pending_points ):
98112 return 0
99113 else :
@@ -105,13 +119,19 @@ def remove_unfinished(self):
105119 self ._to_do_indices .add (i )
106120 self .pending_points = set ()
107121
108- def tell (self , point , value ):
122+ def tell (
123+ self ,
124+ point : Union [
125+ Tuple [int , int ], Tuple [int , float64 ], Tuple [int , ndarray ], Tuple [int , None ]
126+ ],
127+ value : Union [float64 , float ],
128+ ) -> None :
109129 index , point = point
110130 self .data [index ] = value
111131 self .pending_points .discard (index )
112132 self ._to_do_indices .discard (index )
113133
114- def tell_pending (self , point ) :
134+ def tell_pending (self , point : Any ) -> None :
115135 index , point = point
116136 self .pending_points .add (index )
117137 self ._to_do_indices .discard (index )
@@ -126,5 +146,5 @@ def result(self):
126146 return list (self .data .values ())
127147
128148 @property
129- def npoints (self ):
149+ def npoints (self ) -> int :
130150 return len (self .data )
0 commit comments