Skip to content

Commit 9b47e03

Browse files
committed
add AverageLearner1D to test_learners.py
1 parent 55a1dd3 commit 9b47e03

1 file changed

Lines changed: 46 additions & 11 deletions

File tree

adaptive/tests/test_learners.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import adaptive
1818
from adaptive.learner import (
1919
AverageLearner,
20+
AverageLearner1D,
2021
BalancingLearner,
2122
DataSaver,
2223
IntegratorLearner,
@@ -146,6 +147,18 @@ def gaussian(n):
146147
return random.gauss(1, 1)
147148

148149

150+
@learn_with(AverageLearner1D, bounds=[-2, 2])
151+
def noisy_peak(
152+
x,
153+
sigma: uniform(1.5, 2.5),
154+
peak_width: uniform(0.04, 0.06),
155+
offset: uniform(-0.6, -0.3),
156+
):
157+
y = x ** 3 - x + 3 * peak_width ** 2 / (peak_width ** 2 + (x - offset) ** 2)
158+
noise = np.random.normal(0, sigma)
159+
return y + noise
160+
161+
149162
# Decorators for tests.
150163

151164

@@ -252,7 +265,7 @@ def f(x):
252265
simple(learner, goal=lambda l: l.npoints > 10)
253266

254267

255-
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
268+
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner, AverageLearner1D)
256269
def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
257270
"""Adding already existing data is an idempotent operation.
258271
@@ -299,7 +312,14 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
299312

300313
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
301314
# but we xfail it now, as Learner2D will be deprecated anyway
302-
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
315+
@run_with(
316+
Learner1D,
317+
xfail(Learner2D),
318+
LearnerND,
319+
AverageLearner,
320+
AverageLearner1D,
321+
SequenceLearner,
322+
)
303323
def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
304324
"""Adding data for a point that was not returned by 'ask'."""
305325
# XXX: learner, control and bounds are not defined
@@ -341,7 +361,9 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
341361
assert set(pls) == set(cpls)
342362

343363

344-
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
364+
@run_with(
365+
Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner, AverageLearner1D
366+
)
345367
def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
346368
"""The order of calls to 'tell' between calls to 'ask'
347369
is arbitrary.
@@ -383,7 +405,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
383405

384406
# XXX: the Learner2D fails with ~50% chance
385407
# see https://github.com/python-adaptive/adaptive/issues/55
386-
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
408+
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, AverageLearner1D)
387409
def test_expected_loss_improvement_is_less_than_total_loss(
388410
learner_type, f, learner_kwargs
389411
):
@@ -411,7 +433,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
411433

412434
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
413435
# but we xfail it now, as Learner2D will be deprecated anyway
414-
@run_with(Learner1D, xfail(Learner2D), LearnerND)
436+
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner1D)
415437
def test_learner_performance_is_invariant_under_scaling(
416438
learner_type, f, learner_kwargs
417439
):
@@ -464,6 +486,7 @@ def test_learner_performance_is_invariant_under_scaling(
464486
Learner2D,
465487
LearnerND,
466488
AverageLearner,
489+
AverageLearner1D,
467490
SequenceLearner,
468491
with_all_loss_functions=False,
469492
)
@@ -498,16 +521,20 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
498521
x = stash.pop()
499522
learner.tell(x, learner.function(x))
500523

501-
assert all(l.npoints > 5 for l in learner.learners), [
502-
l.npoints for l in learner.learners
503-
]
524+
if learner_type is AverageLearner1D:
525+
nsamples = [l.nsamples for l in learner.learners]
526+
assert all(l.nsamples > 5 for l in learner.learners), nsamples
527+
else:
528+
npoints = [l.npoints for l in learner.learners]
529+
assert all(l.npoints > 5 for l in learner.learners), npoints
504530

505531

506532
@run_with(
507533
Learner1D,
508534
Learner2D,
509535
LearnerND,
510536
AverageLearner,
537+
AverageLearner1D,
511538
maybe_skip(SKOptLearner),
512539
IntegratorLearner,
513540
SequenceLearner,
@@ -540,6 +567,7 @@ def test_saving(learner_type, f, learner_kwargs):
540567
Learner2D,
541568
LearnerND,
542569
AverageLearner,
570+
AverageLearner1D,
543571
maybe_skip(SKOptLearner),
544572
IntegratorLearner,
545573
SequenceLearner,
@@ -578,6 +606,7 @@ def fname(learner):
578606
Learner2D,
579607
LearnerND,
580608
AverageLearner,
609+
AverageLearner1D,
581610
maybe_skip(SKOptLearner),
582611
IntegratorLearner,
583612
with_all_loss_functions=False,
@@ -589,11 +618,17 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
589618
learner = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
590619
control = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
591620

592-
if learner_type is Learner1D:
621+
if learner_type in (Learner1D, AverageLearner1D):
593622
learner.learner._recompute_losses_factor = 1
594623
control.learner._recompute_losses_factor = 1
595624

596-
simple(learner, lambda l: l.npoints > 100)
625+
def goal(n):
626+
if learner_type is AverageLearner1D:
627+
return lambda l: l.nsamples > n
628+
else:
629+
return lambda l: l.npoints > n
630+
631+
simple(learner, goal(100))
597632
fd, path = tempfile.mkstemp()
598633
os.close(fd)
599634
try:
@@ -605,7 +640,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
605640
assert learner.extra_data == control.extra_data
606641

607642
# Try if the control is runnable
608-
simple(control, lambda l: l.npoints > 200)
643+
simple(control, goal(200))
609644
finally:
610645
os.remove(path)
611646

0 commit comments

Comments
 (0)