Skip to content

Commit df52232

Browse files
committed
add simple_run to the tests
1 parent aed824a commit df52232

1 file changed

Lines changed: 29 additions & 18 deletions

File tree

adaptive/tests/test_learners.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,23 @@ def uniform(a, b):
8787
return lambda: random.uniform(a, b)
8888

8989

90+
def simple_run(learner, n):
91+
def get_goal(learner):
92+
if hasattr(learner, "nsamples"):
93+
return lambda l: l.nsamples > n
94+
else:
95+
return lambda l: l.npoints > n
96+
97+
def goal():
98+
if isinstance(learner, BalancingLearner):
99+
return get_goal(learner.learners[0])
100+
elif isinstance(learner, DataSaver):
101+
return get_goal(learner.learner)
102+
return get_goal(learner)
103+
104+
simple(learner, goal())
105+
106+
90107
# Library of functions and associated learners.
91108

92109
learner_function_combos = collections.defaultdict(list)
@@ -262,7 +279,7 @@ def f(x):
262279
return [0, 1]
263280

264281
learner = learner_type(f, bounds=bounds)
265-
simple(learner, goal=lambda l: l.npoints > 10)
282+
simple_run(learner, 10)
266283

267284

268285
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner, AverageLearner1D)
@@ -275,7 +292,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
275292
f = generate_random_parametrization(f)
276293
learner = learner_type(f, **learner_kwargs)
277294
control = learner_type(f, **learner_kwargs)
278-
if learner_type is Learner1D:
295+
if learner_type in (Learner1D, AverageLearner1D):
279296
learner._recompute_losses_factor = 1
280297
control._recompute_losses_factor = 1
281298

@@ -377,7 +394,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
377394
learner = learner_type(f, **learner_kwargs)
378395
control = learner_type(f, **learner_kwargs)
379396

380-
if learner_type is Learner1D:
397+
if learner_type in (Learner1D, AverageLearner1D):
381398
learner._recompute_losses_factor = 1
382399
control._recompute_losses_factor = 1
383400

@@ -425,7 +442,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
425442
assert sum(loss_improvements) < sum(
426443
learner.loss_per_triangle(learner.interpolator(scaled=True))
427444
)
428-
elif learner_type is Learner1D:
445+
elif learner_type in (Learner1D, AverageLearner1D):
429446
assert sum(loss_improvements) < sum(learner.losses.values())
430447
elif learner_type is AverageLearner:
431448
assert sum(loss_improvements) < learner.loss()
@@ -544,10 +561,10 @@ def test_saving(learner_type, f, learner_kwargs):
544561
f = generate_random_parametrization(f)
545562
learner = learner_type(f, **learner_kwargs)
546563
control = learner_type(f, **learner_kwargs)
547-
if learner_type is Learner1D:
564+
if learner_type in (Learner1D, AverageLearner1D):
548565
learner._recompute_losses_factor = 1
549566
control._recompute_losses_factor = 1
550-
simple(learner, lambda l: l.npoints > 100)
567+
simple_run(learner, 100)
551568
fd, path = tempfile.mkstemp()
552569
os.close(fd)
553570
try:
@@ -557,7 +574,7 @@ def test_saving(learner_type, f, learner_kwargs):
557574
np.testing.assert_almost_equal(learner.loss(), control.loss())
558575

559576
# Try if the control is runnable
560-
simple(control, lambda l: l.npoints > 200)
577+
simple_run(control, 200)
561578
finally:
562579
os.remove(path)
563580

@@ -578,12 +595,12 @@ def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
578595
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
579596
control = BalancingLearner([learner_type(f, **learner_kwargs)])
580597

581-
if learner_type is Learner1D:
598+
if learner_type in (Learner1D, AverageLearner1D):
582599
for l, c in zip(learner.learners, control.learners):
583600
l._recompute_losses_factor = 1
584601
c._recompute_losses_factor = 1
585602

586-
simple(learner, lambda l: l.learners[0].npoints > 100)
603+
simple_run(learner, 100)
587604
folder = tempfile.mkdtemp()
588605

589606
def fname(learner):
@@ -596,7 +613,7 @@ def fname(learner):
596613
np.testing.assert_almost_equal(learner.loss(), control.loss())
597614

598615
# Try if the control is runnable
599-
simple(control, lambda l: l.learners[0].npoints > 200)
616+
simple_run(control, 200)
600617
finally:
601618
shutil.rmtree(folder)
602619

@@ -622,13 +639,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
622639
learner.learner._recompute_losses_factor = 1
623640
control.learner._recompute_losses_factor = 1
624641

625-
def goal(n):
626-
if learner_type is AverageLearner1D:
627-
return lambda l: l.nsamples > n
628-
else:
629-
return lambda l: l.npoints > n
630-
631-
simple(learner, goal(100))
642+
simple_run(learner, 100)
632643
fd, path = tempfile.mkstemp()
633644
os.close(fd)
634645
try:
@@ -640,7 +651,7 @@ def goal(n):
640651
assert learner.extra_data == control.extra_data
641652

642653
# Try if the control is runnable
643-
simple(control, goal(200))
654+
simple_run(control, 200)
644655
finally:
645656
os.remove(path)
646657

0 commit comments

Comments
 (0)