11# -*- coding: utf-8 -*-
22
3+ import pytest
4+
35from adaptive .learner import Learner1D , BalancingLearner
46from adaptive .runner import simple
57
@@ -24,32 +26,30 @@ def test_balancing_learner_loss_cache():
2426 assert bl .loss (real = True ) == real_loss
2527
2628
27- def test_distribute_first_points_over_learners ():
28- for strategy in ['loss' , 'loss_improvements' , 'npoints' ]:
29- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
30- learner = BalancingLearner (learners , strategy = strategy )
31- points , _ = learner .ask (100 )
32- i_learner , xs = zip (* points )
33- # assert that are all learners in the suggested points
34- assert len (set (i_learner )) == len (learners )
35-
36-
37- def test_ask_0 ():
38- for strategy in ['loss' , 'loss_improvements' , 'npoints' ]:
39- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
40- learner = BalancingLearner (learners , strategy = strategy )
41- points , _ = learner .ask (0 )
42- assert len (points ) == 0
43-
44-
45- def test_strategies ():
46- goals = {
47- 'loss' : lambda l : l .loss () < 0.1 ,
48- 'loss_improvements' : lambda l : l .loss () < 0.1 ,
49- 'npoints' : lambda bl : all (l .npoints > 10 for l in bl .learners )
50- }
51-
52- for strategy , goal in goals .items ():
53- learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
54- learner = BalancingLearner (learners , strategy = strategy )
55- simple (learner , goal = goal )
29+ @pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
30+ def test_distribute_first_points_over_learners (strategy ):
31+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
32+ learner = BalancingLearner (learners , strategy = strategy )
33+ points , _ = learner .ask (100 )
34+ i_learner , xs = zip (* points )
35+ # assert that are all learners in the suggested points
36+ assert len (set (i_learner )) == len (learners )
37+
38+
39+ @pytest .mark .parametrize ('strategy' , ['loss' , 'loss_improvements' , 'npoints' ])
40+ def test_ask_0 (strategy ):
41+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
42+ learner = BalancingLearner (learners , strategy = strategy )
43+ points , _ = learner .ask (0 )
44+ assert len (points ) == 0
45+
46+
47+ @pytest .mark .parametrize ('strategy, goal' , [
48+ ('loss' , lambda l : l .loss () < 0.1 ),
49+ ('loss_improvements' , lambda l : l .loss () < 0.1 ),
50+ ('npoints' , lambda bl : all (l .npoints > 10 for l in bl .learners )),
51+ ])
52+ def test_strategies (strategy , goal ):
53+ learners = [Learner1D (lambda x : x , bounds = (- 1 , 1 )) for i in range (10 )]
54+ learner = BalancingLearner (learners , strategy = strategy )
55+ simple (learner , goal = goal )
0 commit comments