@@ -87,6 +87,23 @@ def uniform(a, b):
8787 return lambda : random .uniform (a , b )
8888
8989
90+ def simple_run (learner , n ):
91+ def get_goal (learner ):
92+ if hasattr (learner , "nsamples" ):
93+ return lambda l : l .nsamples > n
94+ else :
95+ return lambda l : l .npoints > n
96+
97+ def goal ():
98+ if isinstance (learner , BalancingLearner ):
99+ return get_goal (learner .learners [0 ])
100+ elif isinstance (learner , DataSaver ):
101+ return get_goal (learner .learner )
102+ return get_goal (learner )
103+
104+ simple (learner , goal ())
105+
106+
90107# Library of functions and associated learners.
91108
92109learner_function_combos = collections .defaultdict (list )
@@ -262,7 +279,7 @@ def f(x):
262279 return [0 , 1 ]
263280
264281 learner = learner_type (f , bounds = bounds )
265- simple (learner , goal = lambda l : l . npoints > 10 )
282+ simple_run (learner , 10 )
266283
267284
268285@run_with (Learner1D , Learner2D , LearnerND , SequenceLearner , AverageLearner1D )
@@ -275,7 +292,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
275292 f = generate_random_parametrization (f )
276293 learner = learner_type (f , ** learner_kwargs )
277294 control = learner_type (f , ** learner_kwargs )
278- if learner_type is Learner1D :
295+ if learner_type in ( Learner1D , AverageLearner1D ) :
279296 learner ._recompute_losses_factor = 1
280297 control ._recompute_losses_factor = 1
281298
@@ -377,7 +394,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
377394 learner = learner_type (f , ** learner_kwargs )
378395 control = learner_type (f , ** learner_kwargs )
379396
380- if learner_type is Learner1D :
397+ if learner_type in ( Learner1D , AverageLearner1D ) :
381398 learner ._recompute_losses_factor = 1
382399 control ._recompute_losses_factor = 1
383400
@@ -425,7 +442,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
425442 assert sum (loss_improvements ) < sum (
426443 learner .loss_per_triangle (learner .interpolator (scaled = True ))
427444 )
428- elif learner_type is Learner1D :
445+ elif learner_type in ( Learner1D , AverageLearner1D ) :
429446 assert sum (loss_improvements ) < sum (learner .losses .values ())
430447 elif learner_type is AverageLearner :
431448 assert sum (loss_improvements ) < learner .loss ()
@@ -544,10 +561,10 @@ def test_saving(learner_type, f, learner_kwargs):
544561 f = generate_random_parametrization (f )
545562 learner = learner_type (f , ** learner_kwargs )
546563 control = learner_type (f , ** learner_kwargs )
547- if learner_type is Learner1D :
564+ if learner_type in ( Learner1D , AverageLearner1D ) :
548565 learner ._recompute_losses_factor = 1
549566 control ._recompute_losses_factor = 1
550- simple (learner , lambda l : l . npoints > 100 )
567+ simple_run (learner , 100 )
551568 fd , path = tempfile .mkstemp ()
552569 os .close (fd )
553570 try :
@@ -557,7 +574,7 @@ def test_saving(learner_type, f, learner_kwargs):
557574 np .testing .assert_almost_equal (learner .loss (), control .loss ())
558575
559576 # Try if the control is runnable
560- simple ( control , lambda l : l . npoints > 200 )
577+ simple_run ( learner , 200 )
561578 finally :
562579 os .remove (path )
563580
@@ -578,12 +595,12 @@ def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
578595 learner = BalancingLearner ([learner_type (f , ** learner_kwargs )])
579596 control = BalancingLearner ([learner_type (f , ** learner_kwargs )])
580597
581- if learner_type is Learner1D :
598+ if learner_type in ( Learner1D , AverageLearner1D ) :
582599 for l , c in zip (learner .learners , control .learners ):
583600 l ._recompute_losses_factor = 1
584601 c ._recompute_losses_factor = 1
585602
586- simple (learner , lambda l : l . learners [ 0 ]. npoints > 100 )
603+ simple_run (learner , 100 )
587604 folder = tempfile .mkdtemp ()
588605
589606 def fname (learner ):
@@ -596,7 +613,7 @@ def fname(learner):
596613 np .testing .assert_almost_equal (learner .loss (), control .loss ())
597614
598615 # Try if the control is runnable
599- simple (control , lambda l : l . learners [ 0 ]. npoints > 200 )
616+ simple_run (control , 200 )
600617 finally :
601618 shutil .rmtree (folder )
602619
@@ -622,13 +639,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
622639 learner .learner ._recompute_losses_factor = 1
623640 control .learner ._recompute_losses_factor = 1
624641
625- def goal (n ):
626- if learner_type is AverageLearner1D :
627- return lambda l : l .nsamples > n
628- else :
629- return lambda l : l .npoints > n
630-
631- simple (learner , goal (100 ))
642+ simple_run (learner , 100 )
632643 fd , path = tempfile .mkstemp ()
633644 os .close (fd )
634645 try :
@@ -640,7 +651,7 @@ def goal(n):
640651 assert learner .extra_data == control .extra_data
641652
642653 # Try if the control is runnable
643- simple ( control , goal ( 200 ) )
654+ simple_run ( learner , 200 )
644655 finally :
645656 os .remove (path )
646657
0 commit comments