22import warnings
33from collections import OrderedDict
44from copy import copy
5+ from functools import partial
56from math import sqrt
7+ from typing import Any , Callable , List , Optional , Tuple , Union
68
79import numpy as np
10+ from numpy import bool_ , float64 , ndarray
811from scipy import interpolate
12+ from scipy .interpolate .interpnd import LinearNDInterpolator
913
1014from adaptive .learner .base_learner import BaseLearner
1115from adaptive .learner .triangulation import simplex_volume_in_embedding
1519# Learner2D and helper functions.
1620
1721
18- def deviations (ip ) :
22+ def deviations (ip : LinearNDInterpolator ) -> List [ ndarray ] :
1923 """Returns the deviation of the linear estimate.
2024
2125 Is useful when defining custom loss functions.
@@ -52,7 +56,7 @@ def deviation(p, v, g):
5256 return devs
5357
5458
55- def areas (ip ) :
59+ def areas (ip : LinearNDInterpolator ) -> ndarray :
5660 """Returns the area per triangle of the triangulation inside
5761 a `LinearNDInterpolator` instance.
5862
@@ -73,7 +77,7 @@ def areas(ip):
7377 return areas
7478
7579
76- def uniform_loss (ip ) :
80+ def uniform_loss (ip : LinearNDInterpolator ) -> ndarray :
7781 """Loss function that samples the domain uniformly.
7882
7983 Works with `~adaptive.Learner2D` only.
@@ -104,7 +108,7 @@ def uniform_loss(ip):
104108 return np .sqrt (areas (ip ))
105109
106110
107- def resolution_loss_function (min_distance = 0 , max_distance = 1 ) :
111+ def resolution_loss_function (min_distance : int = 0 , max_distance : int = 1 ) -> Callable :
108112 """Loss function that is similar to the `default_loss` function, but you
109113 can set the maximimum and minimum size of a triangle.
110114
@@ -143,7 +147,7 @@ def resolution_loss(ip):
143147 return resolution_loss
144148
145149
146- def minimize_triangle_surface_loss (ip ) :
150+ def minimize_triangle_surface_loss (ip : LinearNDInterpolator ) -> ndarray :
147151 """Loss function that is similar to the distance loss function in the
148152 `~adaptive.Learner1D`. The loss is the area spanned by the 3D
149153 vectors of the vertices.
@@ -189,7 +193,7 @@ def _get_vectors(points):
189193 return np .linalg .norm (np .cross (a , b ) / 2 , axis = 1 )
190194
191195
192- def default_loss (ip ) :
196+ def default_loss (ip : LinearNDInterpolator ) -> ndarray :
193197 """Loss function that combines `deviations` and `areas` of the triangles.
194198
195199 Works with `~adaptive.Learner2D` only.
@@ -209,7 +213,7 @@ def default_loss(ip):
209213 return losses
210214
211215
212- def choose_point_in_triangle (triangle , max_badness ) :
216+ def choose_point_in_triangle (triangle : ndarray , max_badness : int ) -> ndarray :
213217 """Choose a new point in inside a triangle.
214218
215219 If the ratio of the longest edge of the triangle squared
@@ -348,7 +352,14 @@ class Learner2D(BaseLearner):
348352 over each triangle.
349353 """
350354
351- def __init__ (self , function , bounds , loss_per_triangle = None ):
355+ def __init__ (
356+ self ,
357+ function : Union [partial , Callable ],
358+ bounds : Union [
359+ List [Tuple [int , int ]], Tuple [Tuple [int , int ], Tuple [int , int ]], ndarray
360+ ],
361+ loss_per_triangle : Optional [Callable ] = None ,
362+ ) -> None :
352363 self .ndim = len (bounds )
353364 self ._vdim = None
354365 self .loss_per_triangle = loss_per_triangle or default_loss
@@ -369,28 +380,28 @@ def __init__(self, function, bounds, loss_per_triangle=None):
369380 self .stack_size = 10
370381
371382 @property
372- def xy_scale (self ):
383+ def xy_scale (self ) -> ndarray :
373384 xy_scale = self ._xy_scale
374385 if self .aspect_ratio == 1 :
375386 return xy_scale
376387 else :
377388 return np .array ([xy_scale [0 ], xy_scale [1 ] / self .aspect_ratio ])
378389
379- def _scale (self , points ) :
390+ def _scale (self , points : Any ) -> ndarray :
380391 points = np .asarray (points , dtype = float )
381392 return (points - self .xy_mean ) / self .xy_scale
382393
383- def _unscale (self , points ) :
394+ def _unscale (self , points : ndarray ) -> ndarray :
384395 points = np .asarray (points , dtype = float )
385396 return points * self .xy_scale + self .xy_mean
386397
387398 @property
388- def npoints (self ):
399+ def npoints (self ) -> int :
389400 """Number of evaluated points."""
390401 return len (self .data )
391402
392403 @property
393- def vdim (self ):
404+ def vdim (self ) -> int :
394405 """Length of the output of ``learner.function``.
395406 If the output is unsized (when it's a scalar)
396407 then `vdim = 1`.
@@ -406,7 +417,7 @@ def vdim(self):
406417 return self ._vdim or 1
407418
408419 @property
409- def bounds_are_done (self ):
420+ def bounds_are_done (self ) -> bool :
410421 return not any (
411422 (p in self .pending_points or p in self ._stack ) for p in self ._bounds_points
412423 )
@@ -443,7 +454,7 @@ def interpolated_on_grid(self, n=None):
443454 xs , ys = self ._unscale (np .vstack ([xs , ys ]).T ).T
444455 return xs , ys , zs
445456
446- def _data_in_bounds (self ):
457+ def _data_in_bounds (self ) -> Tuple [ ndarray , ndarray ] :
447458 if self .data :
448459 points = np .array (list (self .data .keys ()))
449460 values = np .array (list (self .data .values ()), dtype = float )
@@ -452,7 +463,7 @@ def _data_in_bounds(self):
452463 return points [inds ], values [inds ].reshape (- 1 , self .vdim )
453464 return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
454465
455- def _data_interp (self ):
466+ def _data_interp (self ) -> Any :
456467 if self .pending_points :
457468 points = list (self .pending_points )
458469 if self .bounds_are_done :
@@ -465,7 +476,7 @@ def _data_interp(self):
465476 return points , values
466477 return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
467478
468- def _data_combined (self ):
479+ def _data_combined (self ) -> Tuple [ ndarray , ndarray ] :
469480 points , values = self ._data_in_bounds ()
470481 if not self .pending_points :
471482 return points , values
@@ -483,7 +494,7 @@ def ip(self):
483494 )
484495 return self .interpolator (scaled = True )
485496
486- def interpolator (self , * , scaled = False ):
497+ def interpolator (self , * , scaled = False ) -> LinearNDInterpolator :
487498 """A `scipy.interpolate.LinearNDInterpolator` instance
488499 containing the learner's data.
489500
@@ -514,7 +525,7 @@ def interpolator(self, *, scaled=False):
514525 points , values = self ._data_in_bounds ()
515526 return interpolate .LinearNDInterpolator (points , values )
516527
517- def _interpolator_combined (self ):
528+ def _interpolator_combined (self ) -> LinearNDInterpolator :
518529 """A `scipy.interpolate.LinearNDInterpolator` instance
519530 containing the learner's data *and* interpolated data of
520531 the `pending_points`."""
@@ -524,12 +535,29 @@ def _interpolator_combined(self):
524535 self ._ip_combined = interpolate .LinearNDInterpolator (points , values )
525536 return self ._ip_combined
526537
527- def inside_bounds (self , xy ):
538+ def inside_bounds (
539+ self ,
540+ xy : Union [
541+ Tuple [int , int ],
542+ Tuple [float64 , float ],
543+ Tuple [float64 , float64 ],
544+ Tuple [float , float64 ],
545+ ],
546+ ) -> Union [bool , bool_ ]:
528547 x , y = xy
529548 (xmin , xmax ), (ymin , ymax ) = self .bounds
530549 return xmin <= x <= xmax and ymin <= y <= ymax
531550
532- def tell (self , point , value ):
551+ def tell (
552+ self ,
553+ point : Union [
554+ Tuple [int , int ],
555+ Tuple [float64 , float ],
556+ Tuple [float64 , float64 ],
557+ Tuple [float , float64 ],
558+ ],
559+ value : Union [List [int ], float64 , float ],
560+ ) -> None :
533561 point = tuple (point )
534562 self .data [point ] = value
535563 if not self .inside_bounds (point ):
@@ -538,15 +566,43 @@ def tell(self, point, value):
538566 self ._ip = None
539567 self ._stack .pop (point , None )
540568
541- def tell_pending (self , point ):
569+ def tell_pending (
570+ self ,
571+ point : Union [
572+ Tuple [int , int ],
573+ Tuple [float64 , float ],
574+ Tuple [float64 , float64 ],
575+ Tuple [float , float64 ],
576+ ],
577+ ) -> None :
542578 point = tuple (point )
543579 if not self .inside_bounds (point ):
544580 return
545581 self .pending_points .add (point )
546582 self ._ip_combined = None
547583 self ._stack .pop (point , None )
548584
549- def _fill_stack (self , stack_till = 1 ):
585+ def _fill_stack (
586+ self , stack_till : int = 1
587+ ) -> Union [
588+ Tuple [List [Tuple [float64 , float64 ]], List [float64 ]],
589+ Tuple [
590+ List [
591+ Union [
592+ Tuple [float64 , float64 ],
593+ Tuple [float , float64 ],
594+ Tuple [float64 , float ],
595+ ]
596+ ],
597+ List [float64 ],
598+ ],
599+ Tuple [
600+ List [Union [Tuple [float , float64 ], Tuple [float64 , float64 ]]], List [float64 ]
601+ ],
602+ Tuple [
603+ List [Union [Tuple [float64 , float64 ], Tuple [float64 , float ]]], List [float64 ]
604+ ],
605+ ]:
550606 if len (self .data ) + len (self .pending_points ) < self .ndim + 1 :
551607 raise ValueError ("too few points..." )
552608
@@ -585,7 +641,7 @@ def _fill_stack(self, stack_till=1):
585641
586642 return points_new , losses_new
587643
588- def ask (self , n , tell_pending = True ):
644+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
589645 # Even if tell_pending is False we add the point such that _fill_stack
590646 # will return new points, later we remove these points if needed.
591647 points = list (self ._stack .keys ())
@@ -616,14 +672,14 @@ def ask(self, n, tell_pending=True):
616672 return points [:n ], loss_improvements [:n ]
617673
618674 @cache_latest
619- def loss (self , real = True ):
675+ def loss (self , real : bool = True ) -> float64 :
620676 if not self .bounds_are_done :
621677 return np .inf
622678 ip = self .interpolator (scaled = True ) if real else self ._interpolator_combined ()
623679 losses = self .loss_per_triangle (ip )
624680 return losses .max ()
625681
626- def remove_unfinished (self ):
682+ def remove_unfinished (self ) -> None :
627683 self .pending_points = set ()
628684 for p in self ._bounds_points :
629685 if p not in self .data :
@@ -697,10 +753,10 @@ def plot(self, n=None, tri_alpha=0):
697753
698754 return im .opts (style = im_opts ) * tris .opts (style = tri_opts , ** no_hover )
699755
700- def _get_data (self ):
756+ def _get_data (self ) -> OrderedDict :
701757 return self .data
702758
703- def _set_data (self , data ) :
759+ def _set_data (self , data : OrderedDict ) -> None :
704760 self .data = data
705761 # Remove points from stack if they already exist
706762 for point in copy (self ._stack ):
0 commit comments