1+ from functools import partial
12from math import sqrt
3+ from typing import Callable , Dict , List , Optional , Tuple , Union
24
35import numpy as np
46
@@ -30,7 +32,12 @@ class AverageLearner(BaseLearner):
3032 Number of evaluated points.
3133 """
3234
33- def __init__ (self , function , atol = None , rtol = None ):
35+ def __init__ (
36+ self ,
37+ function : Union [partial , Callable ],
38+ atol : None = None ,
39+ rtol : Optional [Union [int , float ]] = None ,
40+ ) -> None :
3441 if atol is None and rtol is None :
3542 raise Exception ("At least one of `atol` and `rtol` should be set." )
3643 if atol is None :
@@ -48,10 +55,10 @@ def __init__(self, function, atol=None, rtol=None):
4855 self .sum_f_sq = 0
4956
5057 @property
51- def n_requested (self ):
58+ def n_requested (self ) -> int :
5259 return self .npoints + len (self .pending_points )
5360
54- def ask (self , n , tell_pending = True ):
61+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ List [ int ], List [ float ]] :
5562 points = list (range (self .n_requested , self .n_requested + n ))
5663
5764 if any (p in self .data or p in self .pending_points for p in points ):
@@ -68,7 +75,7 @@ def ask(self, n, tell_pending=True):
6875 self .tell_pending (p )
6976 return points , loss_improvements
7077
71- def tell (self , n , value ) :
78+ def tell (self , n : int , value : Union [ int , float ]) -> None :
7279 if n in self .data :
7380 # The point has already been added before.
7481 return
@@ -79,16 +86,16 @@ def tell(self, n, value):
7986 self .sum_f_sq += value ** 2
8087 self .npoints += 1
8188
82- def tell_pending (self , n ) :
89+ def tell_pending (self , n : int ) -> None :
8390 self .pending_points .add (n )
8491
8592 @property
86- def mean (self ):
93+ def mean (self ) -> float :
8794 """The average of all values in `data`."""
8895 return self .sum_f / self .npoints
8996
9097 @property
91- def std (self ):
98+ def std (self ) -> float :
9299 """The corrected sample standard deviation of the values
93100 in `data`."""
94101 n = self .npoints
@@ -101,7 +108,7 @@ def std(self):
101108 return sqrt (numerator / (n - 1 ))
102109
103110 @cache_latest
104- def loss (self , real = True , * , n = None ):
111+ def loss (self , real : bool = True , * , n = None ) -> float :
105112 if n is None :
106113 n = self .npoints if real else self .n_requested
107114 else :
@@ -113,7 +120,7 @@ def loss(self, real=True, *, n=None):
113120 standard_error / self .atol , standard_error / abs (self .mean ) / self .rtol
114121 )
115122
116- def _loss_improvement (self , n ) :
123+ def _loss_improvement (self , n : int ) -> float :
117124 loss = self .loss ()
118125 if np .isfinite (loss ):
119126 return loss - self .loss (n = self .npoints + n )
@@ -139,8 +146,8 @@ def plot(self):
139146 vals = hv .Points (vals )
140147 return hv .operation .histogram (vals , num_bins = num_bins , dimension = 1 )
141148
142- def _get_data (self ):
149+ def _get_data (self ) -> Tuple [ Dict [ int , float ], int , float , float ] :
143150 return (self .data , self .npoints , self .sum_f , self .sum_f_sq )
144151
145- def _set_data (self , data ) :
152+ def _set_data (self , data : Tuple [ Dict [ int , float ], int , float , float ]) -> None :
146153 self .data , self .npoints , self .sum_f , self .sum_f_sq = data
0 commit comments