44from contextlib import suppress
55from functools import partial
66from operator import itemgetter
7- from typing import Any , Callable , Dict , List , Set , Tuple , Union
7+ from typing import (
8+ Any ,
9+ Callable ,
10+ Dict ,
11+ List ,
12+ Literal ,
13+ Optional ,
14+ Sequence ,
15+ Set ,
16+ Tuple ,
17+ Union ,
18+ )
819
920import numpy as np
1021
@@ -18,6 +29,14 @@ def dispatch(child_functions: List[Callable], arg: Any) -> Union[Any]:
1829 return child_functions [index ](x )
1930
2031
32+ STRATEGY_TYPE = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
33+
34+ CDIMS_TYPE = Union [
35+ Sequence [Dict [str , Any ]],
36+ Tuple [Sequence [str ], Sequence [Tuple [Any , ...]]],
37+ ]
38+
39+
2140class BalancingLearner (BaseLearner ):
2241 r"""Choose the optimal points from a set of learners.
2342
@@ -70,7 +89,11 @@ class BalancingLearner(BaseLearner):
7089 """
7190
7291 def __init__ (
73- self , learners : List [BaseLearner ], * , cdims = None , strategy = "loss_improvements"
92+ self ,
93+ learners : List [BaseLearner ],
94+ * ,
95+ cdims : Optional [CDIMS_TYPE ] = None ,
96+ strategy : STRATEGY_TYPE = "loss_improvements"
7497 ) -> None :
7598 self .learners = learners
7699
@@ -89,7 +112,7 @@ def __init__(
89112 "A BalacingLearner can handle only one type" " of learners."
90113 )
91114
92- self .strategy = strategy
115+ self .strategy : STRATEGY_TYPE = strategy
93116
94117 @property
95118 def data (self ) -> Dict [Tuple [int , Any ], Any ]:
@@ -110,7 +133,7 @@ def npoints(self) -> int:
110133 return sum (l .npoints for l in self .learners )
111134
112135 @property
113- def strategy (self ):
136+ def strategy (self ) -> STRATEGY_TYPE :
114137 """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
115138 'cycle'. The points that the `BalancingLearner` choses can be either
116139 based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -121,7 +144,7 @@ def strategy(self):
121144 return self ._strategy
122145
123146 @strategy .setter
124- def strategy (self , strategy ) :
147+ def strategy (self , strategy : STRATEGY_TYPE ) -> None :
125148 self ._strategy = strategy
126149 if strategy == "loss_improvements" :
127150 self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
@@ -255,11 +278,16 @@ def _losses(self, real: bool = True) -> List[float]:
255278 return losses
256279
257280 @cache_latest
258- def loss (self , real : bool = True ) -> Union [ float ] :
281+ def loss (self , real : bool = True ) -> float :
259282 losses = self ._losses (real )
260283 return max (losses )
261284
262- def plot (self , cdims = None , plotter = None , dynamic = True ):
285+ def plot (
286+ self ,
287+ cdims : Optional [CDIMS_TYPE ] = None ,
288+ plotter : Optional [Callable [[BaseLearner ], Any ]] = None ,
289+ dynamic : bool = True ,
290+ ):
263291 """Returns a DynamicMap with sliders.
264292
265293 Parameters
@@ -332,14 +360,18 @@ def plot_function(*args):
332360 vals = {d .name : d .values for d in dm .dimensions () if d .values }
333361 return hv .HoloMap (dm .select (** vals ))
334362
335- def remove_unfinished (self ):
363+ def remove_unfinished (self ) -> None :
336364 """Remove uncomputed data from the learners."""
337365 for learner in self .learners :
338366 learner .remove_unfinished ()
339367
340368 @classmethod
341369 def from_product (
342- cls , f , learner_type , learner_kwargs , combos
370+ cls ,
371+ f ,
372+ learner_type : BaseLearner ,
373+ learner_kwargs : Dict [str , Any ],
374+ combos : Dict [str , Iterable [Any ]],
343375 ) -> "BalancingLearner" :
344376 """Create a `BalancingLearner` with learners of all combinations of
345377 named variables’ values. The `cdims` will be set correctly, so calling
@@ -387,7 +419,11 @@ def from_product(
387419 learners .append (learner )
388420 return cls (learners , cdims = arguments )
389421
390- def save (self , fname : Callable , compress : bool = True ) -> None :
422+ def save (
423+ self ,
424+ fname : Union [Callable [[BaseLearner ], str ], Sequence [str ]],
425+ compress : bool = True ,
426+ ) -> None :
391427 """Save the data of the child learners into pickle files
392428 in a directory.
393429
@@ -425,7 +461,11 @@ def save(self, fname: Callable, compress: bool = True) -> None:
425461 for l in self .learners :
426462 l .save (fname (l ), compress = compress )
427463
428- def load (self , fname : Callable , compress : bool = True ) -> None :
464+ def load (
465+ self ,
466+ fname : Union [Callable [[BaseLearner ], str ], Sequence [str ]],
467+ compress : bool = True ,
468+ ) -> None :
429469 """Load the data of the child learners from pickle files
430470 in a directory.
431471
@@ -449,20 +489,20 @@ def load(self, fname: Callable, compress: bool = True) -> None:
449489 for l in self .learners :
450490 l .load (fname (l ), compress = compress )
451491
452- def _get_data (self ):
492+ def _get_data (self ) -> List [ Any ] :
453493 return [l ._get_data () for l in self .learners ]
454494
455- def _set_data (self , data ):
495+ def _set_data (self , data : List [ Any ] ):
456496 for l , _data in zip (self .learners , data ):
457497 l ._set_data (_data )
458498
459- def __getstate__ (self ):
499+ def __getstate__ (self ) -> Tuple [ List [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] :
460500 return (
461501 self .learners ,
462502 self ._cdims_default ,
463503 self .strategy ,
464504 )
465505
466- def __setstate__ (self , state ):
506+ def __setstate__ (self , state : Tuple [ List [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] ):
467507 learners , cdims , strategy = state
468508 self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments