@@ -710,12 +710,34 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
710710 kw = {"point_names" : list ("xyz" )[: len (learner_kwargs ["bounds" ])]}
711711 else :
712712 kw = {}
713+
713714 learner = learner_type (generate_random_parametrization (f ), ** learner_kwargs )
714715 simple_run (learner , 100 )
715716 df = learner .to_dataframe (** kw )
716717 assert isinstance (df , pandas .DataFrame )
717- assert len (df ) == learner .npoints
718+ if learner_type is AverageLearner1D :
719+ assert len (df ) == learner .nsamples
720+ else :
721+ assert len (df ) == learner .npoints
722+
723+ # Add points from the DataFrame to a new empty learner
724+ learner2 = learner_type (generate_random_parametrization (f ), ** learner_kwargs )
725+
726+ if learner_type is Learner1D :
727+ learner2 .tell_many (df ["x" ], df ["y" ])
728+ elif learner_type is Learner2D :
729+ learner2 .tell_many (df [["x" , "y" ]].values , df ["z" ])
730+ elif learner_type is LearnerND :
731+ point_names = list (kw ["point_names" ])
732+ learner2 .tell_many (df [point_names ].values , df ["value" ])
733+ elif learner_type is AverageLearner :
734+ learner2 .tell_many (df ["seed" ].values , df ["y" ])
735+ elif learner_type is AverageLearner1D :
736+ learner2 .tell_many (df [["seed" , "x" ]].values , df ["y" ])
737+ else :
738+ raise NotImplementedError ()
718739
740+ # Test this for a learner in a BalancingLearner
719741 learners = [
720742 learner_type (generate_random_parametrization (f ), ** learner_kwargs )
721743 for _ in range (2 )
@@ -724,4 +746,8 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
724746 simple_run (learner , 100 )
725747 df = learner .to_dataframe (** kw )
726748 assert isinstance (df , pandas .DataFrame )
727- assert len (df ) == learner .npoints
749+
750+ if learner_type is not AverageLearner1D :
751+ assert len (df ) == learner .npoints
752+
753+ # TODO: Test this for a learner in a DataSaver
0 commit comments