Skip to content

Commit f931bf1

Browse files
committed
add type annotations for adaptive/learner/learner2D.py
1 parent 4f28a4c commit f931bf1

1 file changed

Lines changed: 84 additions & 28 deletions

File tree

adaptive/learner/learner2D.py

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
import warnings
33
from collections import OrderedDict
44
from copy import copy
5+
from functools import partial
56
from math import sqrt
7+
from typing import Any, Callable, List, Optional, Tuple, Union
68

79
import numpy as np
10+
from numpy import bool_, float64, ndarray
811
from scipy import interpolate
12+
from scipy.interpolate.interpnd import LinearNDInterpolator
913

1014
from adaptive.learner.base_learner import BaseLearner
1115
from adaptive.learner.triangulation import simplex_volume_in_embedding
@@ -15,7 +19,7 @@
1519
# Learner2D and helper functions.
1620

1721

18-
def deviations(ip):
22+
def deviations(ip: LinearNDInterpolator) -> List[ndarray]:
1923
"""Returns the deviation of the linear estimate.
2024
2125
Is useful when defining custom loss functions.
@@ -52,7 +56,7 @@ def deviation(p, v, g):
5256
return devs
5357

5458

55-
def areas(ip):
59+
def areas(ip: LinearNDInterpolator) -> ndarray:
5660
"""Returns the area per triangle of the triangulation inside
5761
a `LinearNDInterpolator` instance.
5862
@@ -73,7 +77,7 @@ def areas(ip):
7377
return areas
7478

7579

76-
def uniform_loss(ip):
80+
def uniform_loss(ip: LinearNDInterpolator) -> ndarray:
7781
"""Loss function that samples the domain uniformly.
7882
7983
Works with `~adaptive.Learner2D` only.
@@ -104,7 +108,7 @@ def uniform_loss(ip):
104108
return np.sqrt(areas(ip))
105109

106110

107-
def resolution_loss_function(min_distance=0, max_distance=1):
111+
def resolution_loss_function(min_distance: int = 0, max_distance: int = 1) -> Callable:
108112
"""Loss function that is similar to the `default_loss` function, but you
109113
can set the maximimum and minimum size of a triangle.
110114
@@ -143,7 +147,7 @@ def resolution_loss(ip):
143147
return resolution_loss
144148

145149

146-
def minimize_triangle_surface_loss(ip):
150+
def minimize_triangle_surface_loss(ip: LinearNDInterpolator) -> ndarray:
147151
"""Loss function that is similar to the distance loss function in the
148152
`~adaptive.Learner1D`. The loss is the area spanned by the 3D
149153
vectors of the vertices.
@@ -189,7 +193,7 @@ def _get_vectors(points):
189193
return np.linalg.norm(np.cross(a, b) / 2, axis=1)
190194

191195

192-
def default_loss(ip):
196+
def default_loss(ip: LinearNDInterpolator) -> ndarray:
193197
"""Loss function that combines `deviations` and `areas` of the triangles.
194198
195199
Works with `~adaptive.Learner2D` only.
@@ -209,7 +213,7 @@ def default_loss(ip):
209213
return losses
210214

211215

212-
def choose_point_in_triangle(triangle, max_badness):
216+
def choose_point_in_triangle(triangle: ndarray, max_badness: int) -> ndarray:
213217
"""Choose a new point in inside a triangle.
214218
215219
If the ratio of the longest edge of the triangle squared
@@ -348,7 +352,14 @@ class Learner2D(BaseLearner):
348352
over each triangle.
349353
"""
350354

351-
def __init__(self, function, bounds, loss_per_triangle=None):
355+
def __init__(
356+
self,
357+
function: Union[partial, Callable],
358+
bounds: Union[
359+
List[Tuple[int, int]], Tuple[Tuple[int, int], Tuple[int, int]], ndarray
360+
],
361+
loss_per_triangle: Optional[Callable] = None,
362+
) -> None:
352363
self.ndim = len(bounds)
353364
self._vdim = None
354365
self.loss_per_triangle = loss_per_triangle or default_loss
@@ -369,28 +380,28 @@ def __init__(self, function, bounds, loss_per_triangle=None):
369380
self.stack_size = 10
370381

371382
@property
372-
def xy_scale(self):
383+
def xy_scale(self) -> ndarray:
373384
xy_scale = self._xy_scale
374385
if self.aspect_ratio == 1:
375386
return xy_scale
376387
else:
377388
return np.array([xy_scale[0], xy_scale[1] / self.aspect_ratio])
378389

379-
def _scale(self, points):
390+
def _scale(self, points: Any) -> ndarray:
380391
points = np.asarray(points, dtype=float)
381392
return (points - self.xy_mean) / self.xy_scale
382393

383-
def _unscale(self, points):
394+
def _unscale(self, points: ndarray) -> ndarray:
384395
points = np.asarray(points, dtype=float)
385396
return points * self.xy_scale + self.xy_mean
386397

387398
@property
388-
def npoints(self):
399+
def npoints(self) -> int:
389400
"""Number of evaluated points."""
390401
return len(self.data)
391402

392403
@property
393-
def vdim(self):
404+
def vdim(self) -> int:
394405
"""Length of the output of ``learner.function``.
395406
If the output is unsized (when it's a scalar)
396407
then `vdim = 1`.
@@ -406,7 +417,7 @@ def vdim(self):
406417
return self._vdim or 1
407418

408419
@property
409-
def bounds_are_done(self):
420+
def bounds_are_done(self) -> bool:
410421
return not any(
411422
(p in self.pending_points or p in self._stack) for p in self._bounds_points
412423
)
@@ -443,7 +454,7 @@ def interpolated_on_grid(self, n=None):
443454
xs, ys = self._unscale(np.vstack([xs, ys]).T).T
444455
return xs, ys, zs
445456

446-
def _data_in_bounds(self):
457+
def _data_in_bounds(self) -> Tuple[ndarray, ndarray]:
447458
if self.data:
448459
points = np.array(list(self.data.keys()))
449460
values = np.array(list(self.data.values()), dtype=float)
@@ -452,7 +463,7 @@ def _data_in_bounds(self):
452463
return points[inds], values[inds].reshape(-1, self.vdim)
453464
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
454465

455-
def _data_interp(self):
466+
def _data_interp(self) -> Any:
456467
if self.pending_points:
457468
points = list(self.pending_points)
458469
if self.bounds_are_done:
@@ -465,7 +476,7 @@ def _data_interp(self):
465476
return points, values
466477
return np.zeros((0, 2)), np.zeros((0, self.vdim), dtype=float)
467478

468-
def _data_combined(self):
479+
def _data_combined(self) -> Tuple[ndarray, ndarray]:
469480
points, values = self._data_in_bounds()
470481
if not self.pending_points:
471482
return points, values
@@ -483,7 +494,7 @@ def ip(self):
483494
)
484495
return self.interpolator(scaled=True)
485496

486-
def interpolator(self, *, scaled=False):
497+
def interpolator(self, *, scaled=False) -> LinearNDInterpolator:
487498
"""A `scipy.interpolate.LinearNDInterpolator` instance
488499
containing the learner's data.
489500
@@ -514,7 +525,7 @@ def interpolator(self, *, scaled=False):
514525
points, values = self._data_in_bounds()
515526
return interpolate.LinearNDInterpolator(points, values)
516527

517-
def _interpolator_combined(self):
528+
def _interpolator_combined(self) -> LinearNDInterpolator:
518529
"""A `scipy.interpolate.LinearNDInterpolator` instance
519530
containing the learner's data *and* interpolated data of
520531
the `pending_points`."""
@@ -524,12 +535,29 @@ def _interpolator_combined(self):
524535
self._ip_combined = interpolate.LinearNDInterpolator(points, values)
525536
return self._ip_combined
526537

527-
def inside_bounds(self, xy):
538+
def inside_bounds(
539+
self,
540+
xy: Union[
541+
Tuple[int, int],
542+
Tuple[float64, float],
543+
Tuple[float64, float64],
544+
Tuple[float, float64],
545+
],
546+
) -> Union[bool, bool_]:
528547
x, y = xy
529548
(xmin, xmax), (ymin, ymax) = self.bounds
530549
return xmin <= x <= xmax and ymin <= y <= ymax
531550

532-
def tell(self, point, value):
551+
def tell(
552+
self,
553+
point: Union[
554+
Tuple[int, int],
555+
Tuple[float64, float],
556+
Tuple[float64, float64],
557+
Tuple[float, float64],
558+
],
559+
value: Union[List[int], float64, float],
560+
) -> None:
533561
point = tuple(point)
534562
self.data[point] = value
535563
if not self.inside_bounds(point):
@@ -538,15 +566,43 @@ def tell(self, point, value):
538566
self._ip = None
539567
self._stack.pop(point, None)
540568

541-
def tell_pending(self, point):
569+
def tell_pending(
570+
self,
571+
point: Union[
572+
Tuple[int, int],
573+
Tuple[float64, float],
574+
Tuple[float64, float64],
575+
Tuple[float, float64],
576+
],
577+
) -> None:
542578
point = tuple(point)
543579
if not self.inside_bounds(point):
544580
return
545581
self.pending_points.add(point)
546582
self._ip_combined = None
547583
self._stack.pop(point, None)
548584

549-
def _fill_stack(self, stack_till=1):
585+
def _fill_stack(
586+
self, stack_till: int = 1
587+
) -> Union[
588+
Tuple[List[Tuple[float64, float64]], List[float64]],
589+
Tuple[
590+
List[
591+
Union[
592+
Tuple[float64, float64],
593+
Tuple[float, float64],
594+
Tuple[float64, float],
595+
]
596+
],
597+
List[float64],
598+
],
599+
Tuple[
600+
List[Union[Tuple[float, float64], Tuple[float64, float64]]], List[float64]
601+
],
602+
Tuple[
603+
List[Union[Tuple[float64, float64], Tuple[float64, float]]], List[float64]
604+
],
605+
]:
550606
if len(self.data) + len(self.pending_points) < self.ndim + 1:
551607
raise ValueError("too few points...")
552608

@@ -585,7 +641,7 @@ def _fill_stack(self, stack_till=1):
585641

586642
return points_new, losses_new
587643

588-
def ask(self, n, tell_pending=True):
644+
def ask(self, n: int, tell_pending: bool = True) -> Any:
589645
# Even if tell_pending is False we add the point such that _fill_stack
590646
# will return new points, later we remove these points if needed.
591647
points = list(self._stack.keys())
@@ -616,14 +672,14 @@ def ask(self, n, tell_pending=True):
616672
return points[:n], loss_improvements[:n]
617673

618674
@cache_latest
619-
def loss(self, real=True):
675+
def loss(self, real: bool = True) -> float64:
620676
if not self.bounds_are_done:
621677
return np.inf
622678
ip = self.interpolator(scaled=True) if real else self._interpolator_combined()
623679
losses = self.loss_per_triangle(ip)
624680
return losses.max()
625681

626-
def remove_unfinished(self):
682+
def remove_unfinished(self) -> None:
627683
self.pending_points = set()
628684
for p in self._bounds_points:
629685
if p not in self.data:
@@ -697,10 +753,10 @@ def plot(self, n=None, tri_alpha=0):
697753

698754
return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
699755

700-
def _get_data(self):
756+
def _get_data(self) -> OrderedDict:
701757
return self.data
702758

703-
def _set_data(self, data):
759+
def _set_data(self, data: OrderedDict) -> None:
704760
self.data = data
705761
# Remove points from stack if they already exist
706762
for point in copy(self._stack):

0 commit comments

Comments
 (0)