22# Copyright 2017 Christoph Groth
33
44from collections import defaultdict
5- from fractions import Fraction as Frac
5+ from fractions import Fraction
6+ from functools import partial
7+ from typing import Callable , List , Tuple , Union
68
79import numpy as np
10+ from numpy import float64 , ndarray
811from numpy .testing import assert_allclose
912from scipy .linalg import inv , norm
1013
1114eps = np .spacing (1 )
1215
1316
14- def legendre (n ) :
17+ def legendre (n : int ) -> List [ List [ Fraction ]] :
1518 """Return the first n Legendre polynomials.
1619
1720 The polynomials have *standard* normalization, i.e.
1821 int_{-1}^1 dx L_n(x) L_m(x) = delta(m, n) * 2 / (2 * n + 1).
1922
2023 The return value is a list of list of fraction.Fraction instances.
2124 """
22- result = [[Frac (1 )], [Frac (0 ), Frac (1 )]]
25+ result = [[Fraction (1 )], [Fraction (0 ), Fraction (1 )]]
2326 if n <= 2 :
2427 return result [:n ]
2528 for i in range (2 , n ):
2629 # Use Bonnet's recursion formula.
27- new = (i + 1 ) * [Frac (0 )]
30+ new = (i + 1 ) * [Fraction (0 )]
2831 new [1 :] = (r * (2 * i - 1 ) for r in result [- 1 ])
2932 new [:- 2 ] = (n - r * (i - 1 ) for n , r in zip (new [:- 2 ], result [- 2 ]))
3033 new [:] = (n / i for n in new )
3134 result .append (new )
3235 return result
3336
3437
35- def newton (n ) :
38+ def newton (n : int ) -> ndarray :
3639 """Compute the monomial coefficients of the Newton polynomial over the
3740 nodes of the n-point Clenshaw-Curtis quadrature rule.
3841 """
@@ -89,7 +92,7 @@ def newton(n):
8992 return cf
9093
9194
92- def scalar_product (a , b ) :
95+ def scalar_product (a : List [ Fraction ] , b : List [ Fraction ]) -> Fraction :
9396 """Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
9497
9598 The args must be sequences of polynomial coefficients. This
@@ -110,7 +113,7 @@ def scalar_product(a, b):
110113 return 2 * sum (c [i ] / (i + 1 ) for i in range (0 , lc , 2 ))
111114
112115
113- def calc_bdef (ns ) :
116+ def calc_bdef (ns : Tuple [ int , int , int , int ]) -> List [ ndarray ] :
114117 """Calculate the decompositions of Newton polynomials (over the nodes
115118 of the n-point Clenshaw-Curtis quadrature rule) in terms of
116119 Legandre polynomials.
@@ -123,7 +126,7 @@ def calc_bdef(ns):
123126 result = []
124127 for n in ns :
125128 poly = []
126- a = list (map (Frac , newton (n )))
129+ a = list (map (Fraction , newton (n )))
127130 for b in legs [: n + 1 ]:
128131 igral = scalar_product (a , b )
129132
@@ -145,7 +148,7 @@ def calc_bdef(ns):
145148b_def = calc_bdef (n )
146149
147150
148- def calc_V (xi , n ) :
151+ def calc_V (xi : ndarray , n : int ) -> ndarray :
149152 V = [np .ones (xi .shape ), xi .copy ()]
150153 for i in range (2 , n ):
151154 V .append ((2 * i - 1 ) / i * xi * V [- 1 ] - (i - 1 ) / i * V [- 2 ])
@@ -183,7 +186,7 @@ def calc_V(xi, n):
183186gamma = np .concatenate ([[0 , 0 ], np .sqrt (k [2 :] ** 2 / (4 * k [2 :] ** 2 - 1 ))])
184187
185188
186- def _downdate (c , nans , depth ) :
189+ def _downdate (c : ndarray , nans : List [ int ] , depth : int ) -> None :
187190 # This is algorithm 5 from the thesis of Pedro Gonnet.
188191 b = b_def [depth ].copy ()
189192 m = n [depth ] - 1
@@ -200,7 +203,7 @@ def _downdate(c, nans, depth):
200203 m -= 1
201204
202205
203- def _zero_nans (fx ) :
206+ def _zero_nans (fx : ndarray ) -> List [ int ] :
204207 nans = []
205208 for i in range (len (fx )):
206209 if not np .isfinite (fx [i ]):
@@ -209,7 +212,7 @@ def _zero_nans(fx):
209212 return nans
210213
211214
212- def _calc_coeffs (fx , depth ) :
215+ def _calc_coeffs (fx : ndarray , depth : int ) -> ndarray :
213216 """Caution: this function modifies fx."""
214217 nans = _zero_nans (fx )
215218 c_new = V_inv [depth ] @ fx
@@ -220,7 +223,7 @@ def _calc_coeffs(fx, depth):
220223
221224
222225class DivergentIntegralError (ValueError ):
223- def __init__ (self , msg , igral , err , nr_points ) :
226+ def __init__ (self , msg : str , igral : float64 , err : None , nr_points : int ) -> None :
224227 self .igral = igral
225228 self .err = err
226229 self .nr_points = nr_points
@@ -230,19 +233,23 @@ def __init__(self, msg, igral, err, nr_points):
230233class _Interval :
231234 __slots__ = ["a" , "b" , "c" , "fx" , "igral" , "err" , "depth" , "rdepth" , "ndiv" , "c00" ]
232235
233- def __init__ (self , a , b , depth , rdepth ):
236+ def __init__ (
237+ self , a : Union [int , float ], b : Union [int , float ], depth : int , rdepth : int
238+ ) -> None :
234239 self .a = a
235240 self .b = b
236241 self .depth = depth
237242 self .rdepth = rdepth
238243
239- def points (self ):
244+ def points (self ) -> ndarray :
240245 a = self .a
241246 b = self .b
242247 return (a + b ) / 2 + (b - a ) * xi [self .depth ] / 2
243248
244249 @classmethod
245- def make_first (cls , f , a , b , depth = 2 ):
250+ def make_first (
251+ cls , f : Union [partial , Callable ], a : int , b : int , depth : int = 2
252+ ) -> Tuple ["_Interval" , int ]:
246253 ival = _Interval (a , b , depth , 1 )
247254 fx = f (ival .points ())
248255 ival .c = _calc_coeffs (fx , depth )
@@ -251,7 +258,7 @@ def make_first(cls, f, a, b, depth=2):
251258 ival .ndiv = 0
252259 return ival , n [depth ]
253260
254- def calc_igral_and_err (self , c_old ) :
261+ def calc_igral_and_err (self , c_old : ndarray ) -> float :
255262 self .c = c_new = _calc_coeffs (self .fx , self .depth )
256263 c_diff = np .zeros (max (len (c_old ), len (c_new )))
257264 c_diff [: len (c_old )] = c_old
@@ -262,7 +269,9 @@ def calc_igral_and_err(self, c_old):
262269 self .err = w * c_diff
263270 return c_diff
264271
265- def split (self , f ):
272+ def split (
273+ self , f : Union [partial , Callable ]
274+ ) -> Union [Tuple [Tuple [float , float , float ], int ], Tuple [List ["_Interval" ], int ]]:
266275 m = (self .a + self .b ) / 2
267276 f_center = self .fx [(len (self .fx ) - 1 ) // 2 ]
268277
@@ -287,7 +296,7 @@ def split(self, f):
287296
288297 return ivals , nr_points
289298
290- def refine (self , f ) :
299+ def refine (self , f : Union [ partial , Callable ]) -> Tuple [ ndarray , bool , int ] :
291300 """Increase degree of interval."""
292301 self .depth = depth = self .depth + 1
293302 points = self .points ()
@@ -299,7 +308,9 @@ def refine(self, f):
299308 return points , split , n [depth ] - n [depth - 1 ]
300309
301310
302- def algorithm_4 (f , a , b , tol , N_loops = int (1e9 )):
311+ def algorithm_4 (
312+ f : Union [partial , Callable ], a : int , b : int , tol : float , N_loops : int = int (1e9 )
313+ ) -> Tuple [float64 , float , int , List ["_Interval" ]]:
303314 """ALGORITHM_4 evaluates an integral using adaptive quadrature. The
304315 algorithm uses Clenshaw-Curtis quadrature rules of increasing
305316 degree in each interval and bisects the interval if either the
@@ -403,37 +414,39 @@ def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
403414 return igral , err , nr_points , ivals
404415
405416
406- ################ Tests ################
417+ # ############### Tests ################
407418
408419
409- def f0 (x ) :
420+ def f0 (x : Union [ float64 , ndarray ]) -> Union [ float64 , ndarray ] :
410421 return x * np .sin (1 / x ) * np .sqrt (abs (1 - x ))
411422
412423
413- def f7 (x ) :
424+ def f7 (x : Union [ float64 , ndarray ]) -> Union [ float64 , ndarray ] :
414425 return x ** - 0.5
415426
416427
417- def f24 (x ) :
428+ def f24 (x : Union [ float64 , ndarray ]) -> Union [ float64 , ndarray ] :
418429 return np .floor (np .exp (x ))
419430
420431
421- def f21 (x ) :
432+ def f21 (x : Union [ float64 , ndarray ]) -> Union [ float64 , ndarray ] :
422433 y = 0
423434 for i in range (1 , 4 ):
424435 y += 1 / np .cosh (20 ** i * (x - 2 * i / 10 ))
425436 return y
426437
427438
428- def f63 (x , alpha , beta ):
439+ def f63 (
440+ x : Union [float64 , ndarray ], alpha : float , beta : float
441+ ) -> Union [float64 , ndarray ]:
429442 return abs (x - beta ) ** alpha
430443
431444
432445def F63 (x , alpha , beta ):
433446 return (x - beta ) * abs (x - beta ) ** alpha / (alpha + 1 )
434447
435448
436- def fdiv (x ) :
449+ def fdiv (x : Union [ float64 , ndarray ]) -> Union [ float64 , ndarray ] :
437450 return abs (x - 0.987654321 ) ** - 1.1
438451
439452
@@ -461,7 +474,9 @@ def test_scalar_product(n=33):
461474 selection = [0 , 5 , 7 , n - 1 ]
462475 for i in selection :
463476 for j in selection :
464- assert scalar_product (legs [i ], legs [j ]) == ((i == j ) and Frac (2 , 2 * i + 1 ))
477+ assert scalar_product (legs [i ], legs [j ]) == (
478+ (i == j ) and Fraction (2 , 2 * i + 1 )
479+ )
465480
466481
467482def simple_newton (n ):
0 commit comments