@@ -19,6 +19,8 @@ class AverageLearner(BaseLearner):
1919 Desired absolute tolerance.
2020 rtol : float
2121 Desired relative tolerance.
22+ min_npoints : int
23+ Minimum number of points to sample.
2224
2325 Attributes
2426 ----------
@@ -30,7 +32,7 @@ class AverageLearner(BaseLearner):
3032 Number of evaluated points.
3133 """
3234
33- def __init__ (self , function , atol = None , rtol = None ):
35+ def __init__ (self , function , atol = None , rtol = None , min_npoints = 2 ):
3436 if atol is None and rtol is None :
3537 raise Exception ("At least one of `atol` and `rtol` should be set." )
3638 if atol is None :
@@ -44,6 +46,8 @@ def __init__(self, function, atol=None, rtol=None):
4446 self .atol = atol
4547 self .rtol = rtol
4648 self .npoints = 0
49+ # Cannot estimate standard deviation with fewer than 2 points.
50+ self .min_npoints = max (min_npoints , 2 )
4751 self .sum_f = 0
4852 self .sum_f_sq = 0
4953
@@ -92,7 +96,7 @@ def std(self):
9296 """The corrected sample standard deviation of the values
9397 in `data`."""
9498 n = self .npoints
95- if n < 2 :
99+ if n < self . min_npoints :
96100 return np .inf
97101 numerator = self .sum_f_sq - n * self .mean ** 2
98102 if numerator < 0 :
@@ -106,7 +110,7 @@ def loss(self, real=True, *, n=None):
106110 n = self .npoints if real else self .n_requested
107111 else :
108112 n = n
109- if n < 2 :
113+ if n < self . min_npoints :
110114 return np .inf
111115 standard_error = self .std / sqrt (n )
112116 return max (
@@ -150,10 +154,11 @@ def __getstate__(self):
150154 self .function ,
151155 self .atol ,
152156 self .rtol ,
157+ self .min_npoints ,
153158 self ._get_data (),
154159 )
155160
156161 def __setstate__ (self , state ):
157- function , atol , rtol , data = state
158- self .__init__ (function , atol , rtol )
162+ function , atol , rtol , min_npoints , data = state
163+ self .__init__ (function , atol , rtol , min_npoints )
159164 self ._set_data (data )
0 commit comments