22
33import sys
44from collections import defaultdict
5+ from functools import partial
56from math import sqrt
67from operator import attrgetter
8+ from typing import Any , Callable , List , Optional , Set , Tuple , Union
79
810import numpy as np
11+ from numpy import bool_ , float64 , ndarray , ufunc
912from scipy .linalg import norm
1013from sortedcontainers import SortedSet
1114
3033)
3134
3235
33- def _downdate (c , nans , depth ) :
36+ def _downdate (c : ndarray , nans : List [ int ] , depth : int ) -> ndarray :
3437 # This is algorithm 5 from the thesis of Pedro Gonnet.
3538 b = b_def [depth ].copy ()
3639 m = ns [depth ] - 1
@@ -48,7 +51,7 @@ def _downdate(c, nans, depth):
4851 return c
4952
5053
51- def _zero_nans (fx ) :
54+ def _zero_nans (fx : ndarray ) -> List [ int ] :
5255 """Caution: this function modifies fx."""
5356 nans = []
5457 for i in range (len (fx )):
@@ -58,7 +61,7 @@ def _zero_nans(fx):
5861 return nans
5962
6063
61- def _calc_coeffs (fx , depth ) :
64+ def _calc_coeffs (fx : ndarray , depth : int ) -> ndarray :
6265 """Caution: this function modifies fx."""
6366 nans = _zero_nans (fx )
6467 c_new = V_inv [depth ] @ fx
@@ -138,7 +141,9 @@ class _Interval:
138141 "removed" ,
139142 ]
140143
141- def __init__ (self , a , b , depth , rdepth ):
144+ def __init__ (
145+ self , a : Union [int , float64 ], b : Union [int , float64 ], depth : int , rdepth : int
146+ ) -> None :
142147 self .children = []
143148 self .data = {}
144149 self .a = a
@@ -150,15 +155,15 @@ def __init__(self, a, b, depth, rdepth):
150155 self .removed = False
151156
152157 @classmethod
153- def make_first (cls , a , b , depth = 2 ) :
158+ def make_first (cls , a : int , b : int , depth : int = 2 ) -> "_Interval" :
154159 ival = _Interval (a , b , depth , rdepth = 1 )
155160 ival .ndiv = 0
156161 ival .parent = None
157162 ival .err = sys .float_info .max # needed because inf/2 == inf
158163 return ival
159164
160165 @property
161- def T (self ):
166+ def T (self ) -> ndarray :
162167 """Get the correct shift matrix.
163168
164169 Should only be called on children of a split interval.
@@ -169,24 +174,24 @@ def T(self):
169174 assert left != right
170175 return T_left if left else T_right
171176
172- def refinement_complete (self , depth ) :
177+ def refinement_complete (self , depth : int ) -> bool :
173178 """The interval has all the y-values to calculate the intergral."""
174179 if len (self .data ) < ns [depth ]:
175180 return False
176181 return all (p in self .data for p in self .points (depth ))
177182
178- def points (self , depth = None ):
183+ def points (self , depth : Optional [ int ] = None ) -> ndarray :
179184 if depth is None :
180185 depth = self .depth
181186 a = self .a
182187 b = self .b
183188 return (a + b ) / 2 + (b - a ) * xi [depth ] / 2
184189
185- def refine (self ):
190+ def refine (self ) -> "_Interval" :
186191 self .depth += 1
187192 return self
188193
189- def split (self ):
194+ def split (self ) -> List [ "_Interval" ] :
190195 points = self .points ()
191196 m = points [len (points ) // 2 ]
192197 ivals = [
@@ -201,10 +206,10 @@ def split(self):
201206
202207 return ivals
203208
204- def calc_igral (self ):
209+ def calc_igral (self ) -> None :
205210 self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
206211
207- def update_heuristic_err (self , value ) :
212+ def update_heuristic_err (self , value : Union [ float64 , float ]) -> None :
208213 """Sets the error of an interval using a heuristic (half the error of
209214 the parent) when the actual error cannot be calculated due to its
210215 parents not being finished yet. This error is propagated down to its
@@ -217,7 +222,7 @@ def update_heuristic_err(self, value):
217222 continue
218223 child .update_heuristic_err (value / 2 )
219224
220- def calc_err (self , c_old ) :
225+ def calc_err (self , c_old : ndarray ) -> float :
221226 c_new = self .c
222227 c_diff = np .zeros (max (len (c_old ), len (c_new )))
223228 c_diff [: len (c_old )] = c_old
@@ -229,7 +234,7 @@ def calc_err(self, c_old):
229234 child .update_heuristic_err (self .err / 2 )
230235 return c_diff
231236
232- def calc_ndiv (self ):
237+ def calc_ndiv (self ) -> None :
233238 div = self .parent .c00 and self .c00 / self .parent .c00 > 2
234239 self .ndiv += div
235240
@@ -240,15 +245,17 @@ def calc_ndiv(self):
240245 for child in self .children :
241246 child .update_ndiv_recursively ()
242247
243- def update_ndiv_recursively (self ):
248+ def update_ndiv_recursively (self ) -> None :
244249 self .ndiv += 1
245250 if self .ndiv > ndiv_max and 2 * self .ndiv > self .rdepth :
246251 raise DivergentIntegralError
247252
248253 for child in self .children :
249254 child .update_ndiv_recursively ()
250255
251- def complete_process (self , depth ):
256+ def complete_process (
257+ self , depth : int
258+ ) -> Union [Tuple [bool , bool ], Tuple [bool , bool_ ]]:
252259 """Calculate the integral contribution and error from this interval,
253260 and update the done leaves of all ancestor intervals."""
254261 assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -323,7 +330,7 @@ def complete_process(self, depth):
323330
324331 return force_split , remove
325332
326- def __repr__ (self ):
333+ def __repr__ (self ) -> str :
327334 lst = [
328335 f"(a, b)=({ self .a :.5f} , { self .b :.5f} )" ,
329336 f"depth={ self .depth } " ,
@@ -335,7 +342,12 @@ def __repr__(self):
335342
336343
337344class IntegratorLearner (BaseLearner ):
338- def __init__ (self , function , bounds , tol ):
345+ def __init__ (
346+ self ,
347+ function : Union [partial , ufunc , Callable ],
348+ bounds : Tuple [int , int ],
349+ tol : float ,
350+ ) -> None :
339351 """
340352 Parameters
341353 ----------
@@ -384,10 +396,10 @@ def __init__(self, function, bounds, tol):
384396 self .first_ival = ival
385397
386398 @property
387- def approximating_intervals (self ):
399+ def approximating_intervals (self ) -> Set [ "_Interval" ] :
388400 return self .first_ival .done_leaves
389401
390- def tell (self , point , value ) :
402+ def tell (self , point : float64 , value : float64 ) -> None :
391403 if point not in self .x_mapping :
392404 raise ValueError (f"Point { point } doesn't belong to any interval" )
393405 self .data [point ] = value
@@ -423,7 +435,7 @@ def tell(self, point, value):
423435 def tell_pending (self ):
424436 pass
425437
426- def propagate_removed (self , ival ) :
438+ def propagate_removed (self , ival : "_Interval" ) -> None :
427439 def _propagate_removed_down (ival ):
428440 ival .removed = True
429441 self .ivals .discard (ival )
@@ -433,7 +445,7 @@ def _propagate_removed_down(ival):
433445
434446 _propagate_removed_down (ival )
435447
436- def add_ival (self , ival ) :
448+ def add_ival (self , ival : "_Interval" ) -> None :
437449 for x in ival .points ():
438450 # Update the mappings
439451 self .x_mapping [x ].add (ival )
@@ -444,15 +456,19 @@ def add_ival(self, ival):
444456 self ._stack .append (x )
445457 self .ivals .add (ival )
446458
447- def ask (self , n , tell_pending = True ):
459+ def ask (
460+ self , n : int , tell_pending : bool = True
461+ ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
448462 """Choose points for learners."""
449463 if not tell_pending :
450464 with restore (self ):
451465 return self ._ask_and_tell_pending (n )
452466 else :
453467 return self ._ask_and_tell_pending (n )
454468
455- def _ask_and_tell_pending (self , n ):
469+ def _ask_and_tell_pending (
470+ self , n : int
471+ ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
456472 points , loss_improvements = self .pop_from_stack (n )
457473 n_left = n - len (points )
458474 while n_left > 0 :
@@ -468,7 +484,13 @@ def _ask_and_tell_pending(self, n):
468484
469485 return points , loss_improvements
470486
471- def pop_from_stack (self , n ):
487+ def pop_from_stack (
488+ self , n : int
489+ ) -> Union [
490+ Tuple [List [float64 ], List [float64 ]],
491+ Tuple [List [Any ], List [Any ]],
492+ Tuple [List [float64 ], List [float ]],
493+ ]:
472494 points = self ._stack [:n ]
473495 self ._stack = self ._stack [n :]
474496 loss_improvements = [
@@ -479,7 +501,7 @@ def pop_from_stack(self, n):
479501 def remove_unfinished (self ):
480502 pass
481503
482- def _fill_stack (self ):
504+ def _fill_stack (self ) -> List [ float64 ] :
483505 # XXX: to-do if all the ivals have err=inf, take the interval
484506 # with the lowest rdepth and no children.
485507 force_split = bool (self .priority_split )
@@ -515,16 +537,16 @@ def _fill_stack(self):
515537 return self ._stack
516538
517539 @property
518- def npoints (self ):
540+ def npoints (self ) -> int :
519541 """Number of evaluated points."""
520542 return len (self .data )
521543
522544 @property
523- def igral (self ):
545+ def igral (self ) -> float64 :
524546 return sum (i .igral for i in self .approximating_intervals )
525547
526548 @property
527- def err (self ):
549+ def err (self ) -> float64 :
528550 if self .approximating_intervals :
529551 err = sum (i .err for i in self .approximating_intervals )
530552 if err > sys .float_info .max :
0 commit comments