@@ -20,6 +20,8 @@ class AverageLearner(BaseLearner):
2020 Desired absolute tolerance.
2121 rtol : float
2222 Desired relative tolerance.
23+ min_npoints : int
24+ Minimum number of points to sample.
2325
2426 Attributes
2527 ----------
@@ -36,6 +38,7 @@ def __init__(
3638 function : Callable ,
3739 atol : Optional [float ] = None ,
3840 rtol : Optional [float ] = None ,
41+ min_npoints : int = 2 ,
3942 ) -> None :
4043 if atol is None and rtol is None :
4144 raise Exception ("At least one of `atol` and `rtol` should be set." )
@@ -50,6 +53,8 @@ def __init__(
5053 self .atol = atol
5154 self .rtol = rtol
5255 self .npoints = 0
56+ # Cannot estimate standard deviation with fewer than 2 points.
57+ self .min_npoints = max (min_npoints , 2 )
5358 self .sum_f = 0
5459 self .sum_f_sq = 0
5560
@@ -98,7 +103,7 @@ def std(self) -> float:
98103 """The corrected sample standard deviation of the values
99104 in `data`."""
100105 n = self .npoints
101- if n < 2 :
106+ if n < self . min_npoints :
102107 return np .inf
103108 numerator = self .sum_f_sq - n * self .mean ** 2
104109 if numerator < 0 :
@@ -112,7 +117,7 @@ def loss(self, real: bool = True, *, n=None) -> float:
112117 n = self .npoints if real else self .n_requested
113118 else :
114119 n = n
115- if n < 2 :
120+ if n < self . min_npoints :
116121 return np .inf
117122 standard_error = self .std / sqrt (n )
118123 return max (
@@ -143,10 +148,24 @@ def plot(self):
143148 return hv .Histogram ([[], []])
144149 num_bins = int (max (5 , sqrt (self .npoints )))
145150 vals = hv .Points (vals )
146- return hv .operation .histogram (vals , num_bins = num_bins , dimension = 1 )
151+ return hv .operation .histogram (vals , num_bins = num_bins , dimension = "y" )
147152
148153 def _get_data (self ) -> Tuple [Dict [int , float ], int , float , float ]:
149154 return (self .data , self .npoints , self .sum_f , self .sum_f_sq )
150155
151156 def _set_data (self , data : Tuple [Dict [int , float ], int , float , float ]) -> None :
152157 self .data , self .npoints , self .sum_f , self .sum_f_sq = data
158+
159+ def __getstate__ (self ):
160+ return (
161+ self .function ,
162+ self .atol ,
163+ self .rtol ,
164+ self .min_npoints ,
165+ self ._get_data (),
166+ )
167+
168+ def __setstate__ (self , state ):
169+ function , atol , rtol , min_npoints , data = state
170+ self .__init__ (function , atol , rtol , min_npoints )
171+ self ._set_data (data )
0 commit comments