Skip to content

Commit 4f28a4c

Browse files
committed
add type annotations for adaptive/learner/learner1D.py
1 parent 5ae3eb4 commit 4f28a4c

1 file changed

Lines changed: 72 additions & 31 deletions

File tree

adaptive/learner/learner1D.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,15 @@
22
import math
33
from collections.abc import Iterable
44
from copy import deepcopy
5+
from functools import partial
6+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
57

68
import numpy as np
79
import sortedcollections
810
import sortedcontainers
11+
from numpy import float64, ndarray
12+
from sortedcollections.recipes import ItemSortedDict
13+
from sortedcontainers.sorteddict import SortedDict
914

1015
from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors
1116
from adaptive.learner.learnerND import volume
@@ -15,7 +20,10 @@
1520

1621

1722
@uses_nth_neighbors(0)
18-
def uniform_loss(xs, ys):
23+
def uniform_loss(
24+
xs: Union[Tuple[float, float], Tuple[float64, float64]],
25+
ys: Union[Tuple[float, float], Tuple[float64, float64]],
26+
) -> Union[float64, float]:
1927
"""Loss function that samples the domain uniformly.
2028
2129
Works with `~adaptive.Learner1D` only.
@@ -35,7 +43,15 @@ def uniform_loss(xs, ys):
3543

3644

3745
@uses_nth_neighbors(0)
38-
def default_loss(xs, ys):
46+
def default_loss(
47+
xs: Union[
48+
Tuple[float, float],
49+
Tuple[float64, float],
50+
Tuple[float64, float64],
51+
Tuple[float, float64],
52+
],
53+
ys: Union[Tuple[float, float], Tuple[ndarray, ndarray], Tuple[float64, float64]],
54+
) -> float64:
3955
"""Calculate loss on a single interval.
4056
4157
Currently returns the rescaled length of the interval. If one of the
@@ -52,7 +68,7 @@ def default_loss(xs, ys):
5268

5369

5470
@uses_nth_neighbors(1)
55-
def triangle_loss(xs, ys):
71+
def triangle_loss(xs: Any, ys: Any) -> Union[float64, float]:
5672
xs = [x for x in xs if x is not None]
5773
ys = [y for y in ys if y is not None]
5874

@@ -69,7 +85,9 @@ def triangle_loss(xs, ys):
6985
return sum(vol(pts[i : i + 3]) for i in range(N)) / N
7086

7187

72-
def curvature_loss_function(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
88+
def curvature_loss_function(
89+
area_factor: int = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
90+
) -> Callable:
7391
# XXX: add a doc-string
7492
@uses_nth_neighbors(1)
7593
def curvature_loss(xs, ys):
@@ -88,7 +106,9 @@ def curvature_loss(xs, ys):
88106
return curvature_loss
89107

90108

91-
def linspace(x_left, x_right, n):
109+
def linspace(
110+
x_left: Union[int, float64, float], x_right: Union[int, float64, float], n: int
111+
) -> Union[List[float], List[float64]]:
92112
"""This is equivalent to
93113
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
94114
but it is 15-30 times faster for small 'n'."""
@@ -100,7 +120,7 @@ def linspace(x_left, x_right, n):
100120
return [x_left + step * i for i in range(1, n)]
101121

102122

103-
def _get_neighbors_from_list(xs):
123+
def _get_neighbors_from_list(xs: ndarray) -> SortedDict:
104124
xs = np.sort(xs)
105125
xs_left = np.roll(xs, 1).tolist()
106126
xs_right = np.roll(xs, -1).tolist()
@@ -110,7 +130,9 @@ def _get_neighbors_from_list(xs):
110130
return sortedcontainers.SortedDict(neighbors)
111131

112132

113-
def _get_intervals(x, neighbors, nth_neighbors):
133+
def _get_intervals(
134+
x: Union[int, float64, float], neighbors: SortedDict, nth_neighbors: int
135+
) -> Any:
114136
nn = nth_neighbors
115137
i = neighbors.index(x)
116138
start = max(0, i - nn - 1)
@@ -163,7 +185,12 @@ class Learner1D(BaseLearner):
163185
decorator for more information.
164186
"""
165187

166-
def __init__(self, function, bounds, loss_per_interval=None):
188+
def __init__(
189+
self,
190+
function: Union[Callable, partial],
191+
bounds: Union[Tuple[int, int], Tuple[float, float], ndarray],
192+
loss_per_interval: Optional[Callable] = None,
193+
) -> None:
167194
self.function = function
168195

169196
if hasattr(loss_per_interval, "nth_neighbors"):
@@ -205,7 +232,7 @@ def __init__(self, function, bounds, loss_per_interval=None):
205232
self._vdim = None
206233

207234
@property
208-
def vdim(self):
235+
def vdim(self) -> int:
209236
"""Length of the output of ``learner.function``.
210237
If the output is unsized (when it's a scalar)
211238
then `vdim = 1`.
@@ -225,35 +252,41 @@ def vdim(self):
225252
return self._vdim
226253

227254
@property
228-
def npoints(self):
255+
def npoints(self) -> int:
229256
"""Number of evaluated points."""
230257
return len(self.data)
231258

232259
@cache_latest
233-
def loss(self, real=True):
260+
def loss(self, real: bool = True) -> Union[int, float64, float]:
234261
losses = self.losses if real else self.losses_combined
235262
if not losses:
236263
return np.inf
237264
max_interval, max_loss = losses.peekitem(0)
238265
return max_loss
239266

240-
def _scale_x(self, x):
267+
def _scale_x(
268+
self, x: Optional[Union[float, int, float64]]
269+
) -> Optional[Union[float, float64]]:
241270
if x is None:
242271
return None
243272
return x / self._scale[0]
244273

245-
def _scale_y(self, y):
274+
def _scale_y(
275+
self, y: Optional[Union[int, ndarray, float64, float]]
276+
) -> Optional[Union[float, float64, ndarray]]:
246277
if y is None:
247278
return None
248279
y_scale = self._scale[1] or 1
249280
return y / y_scale
250281

251-
def _get_point_by_index(self, ind):
282+
def _get_point_by_index(self, ind: int) -> Optional[Union[int, float64, float]]:
252283
if ind < 0 or ind >= len(self.neighbors):
253284
return None
254285
return self.neighbors.keys()[ind]
255286

256-
def _get_loss_in_interval(self, x_left, x_right):
287+
def _get_loss_in_interval(
288+
self, x_left: Union[int, float64, float], x_right: Union[int, float64, float]
289+
) -> Union[int, float64, float]:
257290
assert x_left is not None and x_right is not None
258291

259292
if x_right - x_left < self._dx_eps:
@@ -273,7 +306,9 @@ def _get_loss_in_interval(self, x_left, x_right):
273306
# we need to compute the loss for this interval
274307
return self.loss_per_interval(xs_scaled, ys_scaled)
275308

276-
def _update_interpolated_loss_in_interval(self, x_left, x_right):
309+
def _update_interpolated_loss_in_interval(
310+
self, x_left: Union[int, float64, float], x_right: Union[int, float64, float]
311+
) -> None:
277312
if x_left is None or x_right is None:
278313
return
279314

@@ -289,7 +324,7 @@ def _update_interpolated_loss_in_interval(self, x_left, x_right):
289324
self.losses_combined[a, b] = (b - a) * loss / dx
290325
a = b
291326

292-
def _update_losses(self, x, real=True):
327+
def _update_losses(self, x: Union[int, float64, float], real: bool = True) -> None:
293328
"""Update all losses that depend on x"""
294329
# When we add a new point x, we should update the losses
295330
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -332,7 +367,7 @@ def _update_losses(self, x, real=True):
332367
self.losses_combined[x, b] = float("inf")
333368

334369
@staticmethod
335-
def _find_neighbors(x, neighbors):
370+
def _find_neighbors(x: Union[int, float64, float], neighbors: SortedDict) -> Any:
336371
if x in neighbors:
337372
return neighbors[x]
338373
pos = neighbors.bisect_left(x)
@@ -341,14 +376,18 @@ def _find_neighbors(x, neighbors):
341376
x_right = keys[pos] if pos != len(neighbors) else None
342377
return x_left, x_right
343378

344-
def _update_neighbors(self, x, neighbors):
379+
def _update_neighbors(
380+
self, x: Union[int, float64, float], neighbors: SortedDict
381+
) -> None:
345382
if x not in neighbors: # The point is new
346383
x_left, x_right = self._find_neighbors(x, neighbors)
347384
neighbors[x] = [x_left, x_right]
348385
neighbors.get(x_left, [None, None])[1] = x
349386
neighbors.get(x_right, [None, None])[0] = x
350387

351-
def _update_scale(self, x, y):
388+
def _update_scale(
389+
self, x: Union[int, float64, float], y: Union[float, int, float64, ndarray]
390+
) -> None:
352391
"""Update the scale with which the x and y-values are scaled.
353392
354393
For a learner where the function returns a single scalar the scale
@@ -375,7 +414,7 @@ def _update_scale(self, x, y):
375414
self._bbox[1][1] = max(self._bbox[1][1], y)
376415
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
377416

378-
def tell(self, x, y):
417+
def tell(self, x: Union[int, float64, float], y: Any) -> None:
379418
if x in self.data:
380419
# The point is already evaluated before
381420
return
@@ -410,15 +449,15 @@ def tell(self, x, y):
410449

411450
self._oldscale = deepcopy(self._scale)
412451

413-
def tell_pending(self, x):
452+
def tell_pending(self, x: Union[int, float64, float]) -> None:
414453
if x in self.data:
415454
# The point is already evaluated before
416455
return
417456
self.pending_points.add(x)
418457
self._update_neighbors(x, self.neighbors_combined)
419458
self._update_losses(x, real=False)
420459

421-
def tell_many(self, xs, ys, *, force=False):
460+
def tell_many(self, xs: Any, ys: Any, *, force=False) -> None:
422461
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
423462
# Only run this more efficient method if there are
424463
# at least 2 points and the amount of points added are
@@ -486,7 +525,7 @@ def tell_many(self, xs, ys, *, force=False):
486525
# have an inf loss.
487526
self._update_interpolated_loss_in_interval(*ival)
488527

489-
def ask(self, n, tell_pending=True):
528+
def ask(self, n: int, tell_pending: bool = True) -> Any:
490529
"""Return 'n' points that are expected to maximally reduce the loss."""
491530
points, loss_improvements = self._ask_points_without_adding(n)
492531

@@ -496,7 +535,7 @@ def ask(self, n, tell_pending=True):
496535

497536
return points, loss_improvements
498537

499-
def _ask_points_without_adding(self, n):
538+
def _ask_points_without_adding(self, n: int) -> Any:
500539
"""Return 'n' points that are expected to maximally reduce the loss.
501540
Without altering the state of the learner"""
502541
# Find out how to divide the n points over the intervals
@@ -574,7 +613,7 @@ def _ask_points_without_adding(self, n):
574613

575614
return points, loss_improvements
576615

577-
def _loss(self, mapping, ival):
616+
def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any:
578617
loss = mapping[ival]
579618
return finite_loss(ival, loss, self._scale[0])
580619

@@ -613,20 +652,20 @@ def plot(self, *, scatter_or_line="scatter"):
613652

614653
return p.redim(x=dict(range=plot_bounds))
615654

616-
def remove_unfinished(self):
655+
def remove_unfinished(self) -> None:
617656
self.pending_points = set()
618657
self.losses_combined = deepcopy(self.losses)
619658
self.neighbors_combined = deepcopy(self.neighbors)
620659

621-
def _get_data(self):
660+
def _get_data(self) -> Dict[Union[int, float], float]:
622661
return self.data
623662

624-
def _set_data(self, data):
663+
def _set_data(self, data: Dict[Union[int, float], float]) -> None:
625664
if data:
626665
self.tell_many(*zip(*data.items()))
627666

628667

629-
def loss_manager(x_scale):
668+
def loss_manager(x_scale: Union[int, float64, float]) -> ItemSortedDict:
630669
def sort_key(ival, loss):
631670
loss, ival = finite_loss(ival, loss, x_scale)
632671
return -loss, ival
@@ -635,7 +674,9 @@ def sort_key(ival, loss):
635674
return sorted_dict
636675

637676

638-
def finite_loss(ival, loss, x_scale):
677+
def finite_loss(
678+
ival: Any, loss: Union[int, float64, float], x_scale: Union[int, float64, float]
679+
) -> Any:
639680
"""Get the socalled finite_loss of an interval in order to be able to
640681
sort intervals that have infinite loss."""
641682
# If the loss is infinite we return the

0 commit comments

Comments
 (0)