55import pickle
66from contextlib import contextmanager
77from itertools import product
8+ from typing import Any , Callable , Iterator
89
910from atomicwrites import AtomicWriter
1011
12+ from adaptive .learner .base_learner import BaseLearner
13+
1114
1215def named_product (** items ):
1316 names = items .keys ()
@@ -16,7 +19,7 @@ def named_product(**items):
1619
1720
1821@contextmanager
19- def restore (* learners ):
22+ def restore (* learners ) -> Iterator [ None ] :
2023 states = [learner .__getstate__ () for learner in learners ]
2124 try :
2225 yield
@@ -25,7 +28,7 @@ def restore(*learners):
2528 learner .__setstate__ (state )
2629
2730
28- def cache_latest (f ) :
31+ def cache_latest (f : Callable ) -> Callable :
2932 """Cache the latest return value of the function and add it
3033 as 'self._cache[f.__name__]'."""
3134
@@ -40,7 +43,7 @@ def wrapper(*args, **kwargs):
4043 return wrapper
4144
4245
43- def save (fname , data , compress = True ):
46+ def save (fname : str , data : Any , compress : bool = True ) -> None :
4447 fname = os .path .expanduser (fname )
4548 dirname = os .path .dirname (fname )
4649 if dirname :
@@ -54,22 +57,22 @@ def save(fname, data, compress=True):
5457 f .write (blob )
5558
5659
57- def load (fname , compress = True ):
60+ def load (fname : str , compress : bool = True ):
5861 fname = os .path .expanduser (fname )
5962 _open = gzip .open if compress else open
6063 with _open (fname , "rb" ) as f :
6164 return pickle .load (f )
6265
6366
64- def copy_docstring_from (other ) :
67+ def copy_docstring_from (other : Callable ) -> Callable :
6568 def decorator (method ):
6669 return functools .wraps (other )(method )
6770
6871 return decorator
6972
7073
7174class _RequireAttrsABCMeta (abc .ABCMeta ):
72- def __call__ (self , * args , ** kwargs ):
75+ def __call__ (self , * args , ** kwargs ) -> BaseLearner :
7376 obj = super ().__call__ (* args , ** kwargs )
7477 for name , type_ in obj .__annotations__ .items ():
7578 try :
0 commit comments