Skip to content

Commit dd21310

Browse files
committed
add type hints for adaptive/learner/average_learner.py
1 parent 53d6c7e commit dd21310

1 file changed

Lines changed: 18 additions & 11 deletions

File tree

adaptive/learner/average_learner.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from functools import partial
12
from math import sqrt
3+
from typing import Callable, Dict, List, Optional, Tuple, Union
24

35
import numpy as np
46

@@ -30,7 +32,12 @@ class AverageLearner(BaseLearner):
3032
Number of evaluated points.
3133
"""
3234

33-
def __init__(self, function, atol=None, rtol=None):
35+
def __init__(
36+
self,
37+
function: Union[partial, Callable],
38+
atol: None = None,
39+
rtol: Optional[Union[int, float]] = None,
40+
) -> None:
3441
if atol is None and rtol is None:
3542
raise Exception("At least one of `atol` and `rtol` should be set.")
3643
if atol is None:
@@ -48,10 +55,10 @@ def __init__(self, function, atol=None, rtol=None):
4855
self.sum_f_sq = 0
4956

5057
@property
51-
def n_requested(self):
58+
def n_requested(self) -> int:
5259
return self.npoints + len(self.pending_points)
5360

54-
def ask(self, n, tell_pending=True):
61+
def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[int], List[float]]:
5562
points = list(range(self.n_requested, self.n_requested + n))
5663

5764
if any(p in self.data or p in self.pending_points for p in points):
@@ -68,7 +75,7 @@ def ask(self, n, tell_pending=True):
6875
self.tell_pending(p)
6976
return points, loss_improvements
7077

71-
def tell(self, n, value):
78+
def tell(self, n: int, value: Union[int, float]) -> None:
7279
if n in self.data:
7380
# The point has already been added before.
7481
return
@@ -79,16 +86,16 @@ def tell(self, n, value):
7986
self.sum_f_sq += value ** 2
8087
self.npoints += 1
8188

82-
def tell_pending(self, n):
89+
def tell_pending(self, n: int) -> None:
8390
self.pending_points.add(n)
8491

8592
@property
86-
def mean(self):
93+
def mean(self) -> float:
8794
"""The average of all values in `data`."""
8895
return self.sum_f / self.npoints
8996

9097
@property
91-
def std(self):
98+
def std(self) -> float:
9299
"""The corrected sample standard deviation of the values
93100
in `data`."""
94101
n = self.npoints
@@ -101,7 +108,7 @@ def std(self):
101108
return sqrt(numerator / (n - 1))
102109

103110
@cache_latest
104-
def loss(self, real=True, *, n=None):
111+
def loss(self, real: bool = True, *, n=None) -> float:
105112
if n is None:
106113
n = self.npoints if real else self.n_requested
107114
else:
@@ -113,7 +120,7 @@ def loss(self, real=True, *, n=None):
113120
standard_error / self.atol, standard_error / abs(self.mean) / self.rtol
114121
)
115122

116-
def _loss_improvement(self, n):
123+
def _loss_improvement(self, n: int) -> float:
117124
loss = self.loss()
118125
if np.isfinite(loss):
119126
return loss - self.loss(n=self.npoints + n)
@@ -139,8 +146,8 @@ def plot(self):
139146
vals = hv.Points(vals)
140147
return hv.operation.histogram(vals, num_bins=num_bins, dimension=1)
141148

142-
def _get_data(self):
149+
def _get_data(self) -> Tuple[Dict[int, float], int, float, float]:
143150
return (self.data, self.npoints, self.sum_f, self.sum_f_sq)
144151

145-
def _set_data(self, data):
152+
def _set_data(self, data: Tuple[Dict[int, float], int, float, float]) -> None:
146153
self.data, self.npoints, self.sum_f, self.sum_f_sq = data

0 commit comments

Comments
 (0)