11from functools import partial
22from operator import attrgetter
3+ from typing import Callable , List , Set , Union
34
45import numpy as np
56import pytest
67
7- from adaptive .learner import IntegratorLearner
88from adaptive .learner .integrator_coeffs import ns
9- from adaptive .learner .integrator_learner import DivergentIntegralError
9+ from adaptive .learner .integrator_learner import (
10+ DivergentIntegralError ,
11+ IntegratorLearner ,
12+ _Interval ,
13+ )
1014
1115from .algorithm_4 import DivergentIntegralError as A4DivergentIntegralError
1216from .algorithm_4 import algorithm_4 , f0 , f7 , f21 , f24 , f63 , fdiv
1317
1418eps = np .spacing (1 )
1519
1620
17- def run_integrator_learner (f , a , b , tol , n ):
21+ def run_integrator_learner (
22+ f : Union [partial , Callable ], a : int , b : int , tol : float , n : int
23+ ) -> IntegratorLearner :
1824 learner = IntegratorLearner (f , bounds = (a , b ), tol = tol )
1925 for _ in range (n ):
2026 points , _ = learner .ask (1 )
2127 learner .tell_many (points , map (learner .function , points ))
2228 return learner
2329
2430
25- def equal_ival (ival , other , * , verbose = False ):
31+ def equal_ival (ival : _Interval , other : _Interval , * , verbose = False ) -> bool :
2632 """Note: Implementing __eq__ breaks SortedContainers in some way."""
2733 if ival .depth_complete is None :
2834 if verbose :
@@ -42,7 +48,9 @@ def equal_ival(ival, other, *, verbose=False):
4248 return all (same_slots )
4349
4450
45- def equal_ivals (ivals , other , * , verbose = False ):
51+ def equal_ivals (
52+ ivals : Set [_Interval ], other : List [_Interval ], * , verbose = False
53+ ) -> bool :
4654 """Note: `other` is a list of ivals."""
4755 if len (ivals ) != len (other ):
4856 if verbose :
@@ -56,7 +64,7 @@ def equal_ivals(ivals, other, *, verbose=False):
5664 )
5765
5866
59- def same_ivals (f , a , b , tol ) :
67+ def same_ivals (f : Callable , a : int , b : int , tol : float ) -> bool :
6068 igral , err , n , ivals = algorithm_4 (f , a , b , tol )
6169
6270 learner = run_integrator_learner (f , a , b , tol , n )
@@ -71,15 +79,15 @@ def same_ivals(f, a, b, tol):
7179
7280# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
7381@pytest .mark .xfail
74- def test_that_gives_same_intervals_as_reference_implementation ():
82+ def test_that_gives_same_intervals_as_reference_implementation () -> None :
7583 for i , args in enumerate (
7684 [[f0 , 0 , 3 , 1e-5 ], [f7 , 0 , 1 , 1e-6 ], [f21 , 0 , 1 , 1e-3 ], [f24 , 0 , 3 , 1e-3 ]]
7785 ):
7886 assert same_ivals (* args ), f"Function { i } "
7987
8088
8189@pytest .mark .xfail
82- def test_machine_precision ():
90+ def test_machine_precision () -> None :
8391 f , a , b , tol = [partial (f63 , alpha = 0.987654321 , beta = 0.45 ), 0 , 1 , 1e-10 ]
8492 igral , err , n , ivals = algorithm_4 (f , a , b , tol )
8593
@@ -92,7 +100,7 @@ def test_machine_precision():
92100 assert equal_ivals (learner .ivals , ivals , verbose = True )
93101
94102
95- def test_machine_precision2 ():
103+ def test_machine_precision2 () -> None :
96104 f , a , b , tol = [partial (f63 , alpha = 0.987654321 , beta = 0.45 ), 0 , 1 , 1e-10 ]
97105 igral , err , n , ivals = algorithm_4 (f , a , b , tol )
98106
@@ -102,7 +110,7 @@ def test_machine_precision2():
102110 np .testing .assert_almost_equal (err , learner .err )
103111
104112
105- def test_divergence ():
113+ def test_divergence () -> None :
106114 """This function should raise a DivergentIntegralError."""
107115 f , a , b , tol = fdiv , 0 , 1 , 1e-6
108116 with pytest .raises (A4DivergentIntegralError ) as e :
@@ -114,22 +122,22 @@ def test_divergence():
114122 run_integrator_learner (f , a , b , tol , n )
115123
116124
117- def test_choosing_and_adding_points_one_by_one ():
125+ def test_choosing_and_adding_points_one_by_one () -> None :
118126 learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
119127 for _ in range (1000 ):
120128 xs , _ = learner .ask (1 )
121129 for x in xs :
122130 learner .tell (x , learner .function (x ))
123131
124132
125- def test_choosing_and_adding_multiple_points_at_once ():
133+ def test_choosing_and_adding_multiple_points_at_once () -> None :
126134 learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
127135 xs , _ = learner .ask (100 )
128136 for x in xs :
129137 learner .tell (x , learner .function (x ))
130138
131139
132- def test_adding_points_and_skip_one_point ():
140+ def test_adding_points_and_skip_one_point () -> None :
133141 learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
134142 xs , _ = learner .ask (17 )
135143 skip_x = xs [1 ]
@@ -160,7 +168,7 @@ def test_adding_points_and_skip_one_point():
160168
161169# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
162170@pytest .mark .xfail
163- def test_tell_in_random_order (first_add_33 = False ):
171+ def test_tell_in_random_order (first_add_33 : bool = False ) -> None :
164172 from operator import attrgetter
165173 import random
166174
@@ -219,11 +227,11 @@ def test_tell_in_random_order(first_add_33=False):
219227
220228# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
221229@pytest .mark .xfail
222- def test_tell_in_random_order_first_add_33 ():
230+ def test_tell_in_random_order_first_add_33 () -> None :
223231 test_tell_in_random_order (first_add_33 = True )
224232
225233
226- def test_approximating_intervals ():
234+ def test_approximating_intervals () -> None :
227235 import random
228236
229237 learner = IntegratorLearner (f24 , bounds = (0 , 3 ), tol = 1e-10 )
@@ -252,7 +260,7 @@ def test_removed_choose_mutiple_points_at_once():
252260 assert list (learner .approximating_intervals )[0 ] == learner .first_ival
253261
254262
255- def test_removed_ask_one_by_one ():
263+ def test_removed_ask_one_by_one () -> None :
256264 with pytest .raises (RuntimeError ):
257265 # This test should raise because integrating np.exp should be done
258266 # after the 33th point
0 commit comments