11import functools
22from collections import OrderedDict
3+ from operator import itemgetter
4+ from typing import Callable , Dict , Tuple , Union
35
6+ from adaptive .learner .average_learner import AverageLearner
47from adaptive .learner .base_learner import BaseLearner
8+ from adaptive .learner .learner1D import Learner1D
9+ from adaptive .learner .learner2D import Learner2D
10+ from adaptive .learner .learnerND import LearnerND
511from adaptive .utils import copy_docstring_from
612
713
@@ -25,13 +31,17 @@ class DataSaver:
2531 >>> learner = DataSaver(_learner, arg_picker=itemgetter('y'))
2632 """
2733
28- def __init__ (self , learner , arg_picker ):
34+ def __init__ (
35+ self ,
36+ learner : Union [Learner2D , Learner1D , LearnerND , AverageLearner ],
37+ arg_picker : itemgetter ,
38+ ) -> None :
2939 self .learner = learner
3040 self .extra_data = OrderedDict ()
3141 self .function = learner .function
3242 self .arg_picker = arg_picker
3343
34- def __getattr__ (self , attr ) :
44+ def __getattr__ (self , attr : str ) -> Union [ Callable , int ] :
3545 return getattr (self .learner , attr )
3646
3747 @copy_docstring_from (BaseLearner .tell )
@@ -44,10 +54,23 @@ def tell(self, x, result):
4454 def tell_pending (self , x ):
4555 self .learner .tell_pending (x )
4656
47- def _get_data (self ):
57+ def _get_data (
58+ self ,
59+ ) -> Union [
60+ Tuple [Dict [Union [int , float ], float ], OrderedDict ],
61+ Tuple [OrderedDict , OrderedDict ],
62+ Tuple [Tuple [Dict [int , float ], int , float , float ], OrderedDict ],
63+ ]:
4864 return self .learner ._get_data (), self .extra_data
4965
50- def _set_data (self , data ):
66+ def _set_data (
67+ self ,
68+ data : Union [
69+ Tuple [OrderedDict , OrderedDict ],
70+ Tuple [Dict [Union [int , float ], float ], OrderedDict ],
71+ Tuple [Tuple [Dict [int , float ], int , float , float ], OrderedDict ],
72+ ],
73+ ) -> None :
5174 learner_data , self .extra_data = data
5275 self .learner ._set_data (learner_data )
5376
0 commit comments