11import collections
2+ from typing import Callable , List , Tuple , Union
23
34import numpy as np
5+ from numpy import float64
46from skopt import Optimizer
57
68from adaptive .learner .base_learner import BaseLearner
@@ -23,13 +25,15 @@ class SKOptLearner(Optimizer, BaseLearner):
2325 Arguments to pass to ``skopt.Optimizer``.
2426 """
2527
26- def __init__ (self , function , ** kwargs ):
28+ def __init__ (self , function : Callable , ** kwargs ) -> None :
2729 self .function = function
2830 self .pending_points = set ()
2931 self .data = collections .OrderedDict ()
3032 super ().__init__ (** kwargs )
3133
32- def tell (self , x , y , fit = True ):
34+ def tell (
35+ self , x : Union [float64 , List [float64 ]], y : float64 , fit : bool = True
36+ ) -> None :
3337 if isinstance (x , collections .abc .Iterable ):
3438 self .pending_points .discard (tuple (x ))
3539 self .data [tuple (x )] = y
@@ -48,7 +52,7 @@ def remove_unfinished(self):
4852 pass
4953
5054 @cache_latest
51- def loss (self , real = True ):
55+ def loss (self , real : bool = True ) -> Union [ float64 , float ] :
5256 if not self .models :
5357 return np .inf
5458 else :
@@ -58,7 +62,14 @@ def loss(self, real=True):
5862 # estimator of loss, but it is the cheapest.
5963 return 1 - model .score (self .Xi , self .yi )
6064
61- def ask (self , n , tell_pending = True ):
65+ def ask (
66+ self , n : int , tell_pending : bool = True
67+ ) -> Union [
68+ Tuple [List [float64 ], List [float64 ]],
69+ Tuple [List [List [float64 ]], List [float64 ]],
70+ Tuple [List [List [float64 ]], List [float ]],
71+ Tuple [List [float64 ], List [float ]],
72+ ]:
6273 if not tell_pending :
6374 raise NotImplementedError (
6475 "Asking points is an irreversible "
@@ -72,7 +83,7 @@ def ask(self, n, tell_pending=True):
7283 return [p [0 ] for p in points ], [self .loss () / n ] * n
7384
7485 @property
75- def npoints (self ):
86+ def npoints (self ) -> int :
7687 """Number of evaluated points."""
7788 return len (self .Xi )
7889
0 commit comments