Skip to content

Commit 8394906

Browse files
committed
add adaptive/tests/test_average_learner1d.py
1 parent df52232 commit 8394906

1 file changed

Lines changed: 44 additions & 0 deletions

File tree

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
import pandas as pd
3+
from pandas.testing import assert_series_equal
4+
5+
from adaptive import AverageLearner1D
6+
from adaptive.tests.test_learners import (
7+
generate_random_parametrization,
8+
noisy_peak,
9+
simple_run,
10+
)
11+
12+
13+
def almost_equal_dicts(a, b):
14+
assert_series_equal(pd.Series(sorted(a.items())), pd.Series(sorted(b.items())))
15+
16+
17+
def test_copy_from():
18+
f = generate_random_parametrization(noisy_peak)
19+
learner = AverageLearner1D(f, bounds=[-2, 2])
20+
control = AverageLearner1D(f, bounds=[-2, 2])
21+
learner._recompute_losses_factor = 1
22+
control._recompute_losses_factor = 1
23+
simple_run(learner, 100)
24+
control.copy_from(learner)
25+
26+
almost_equal_dicts(learner.data, control.data)
27+
almost_equal_dicts(learner.error, control.error)
28+
almost_equal_dicts(learner.rescaled_error, control.rescaled_error)
29+
almost_equal_dicts(learner.neighbors, control.neighbors)
30+
almost_equal_dicts(learner.neighbors_combined, control.neighbors_combined)
31+
assert learner.npoints == control.npoints
32+
assert learner.nsamples == control.nsamples
33+
assert len(learner._data_samples) == len(control._data_samples)
34+
assert learner._data_samples.keys() == control._data_samples.keys()
35+
for k, v1 in learner._data_samples.items():
36+
v2 = control._data_samples[k]
37+
assert len(v1) == len(v2)
38+
np.testing.assert_almost_equal(np.sort(v1), np.sort(v2))
39+
40+
almost_equal_dicts(learner.losses, control.losses)
41+
np.testing.assert_almost_equal(learner.loss(), control.loss())
42+
43+
# Try if the control is runnable
44+
simple_run(control, 200)

0 commit comments

Comments
 (0)