44from contextlib import suppress
55from functools import partial
66from operator import itemgetter
7+ from typing import Any , Callable , Dict , List , Set , Tuple , Union
78
89import numpy as np
10+ from numpy import float64 , int64
911
12+ from adaptive .learner .average_learner import AverageLearner
1013from adaptive .learner .base_learner import BaseLearner
14+ from adaptive .learner .learner1D import Learner1D
15+ from adaptive .learner .learner2D import Learner2D
16+ from adaptive .learner .learnerND import LearnerND
17+ from adaptive .learner .sequence_learner import SequenceLearner , _IgnoreFirstArgument
1118from adaptive .notebook_integration import ensure_holoviews
1219from adaptive .utils import cache_latest , named_product , restore
1320
1421
15- def dispatch (child_functions , arg ):
22+ def dispatch (
23+ child_functions : Union [List [Callable ], List [partial ], List [_IgnoreFirstArgument ]],
24+ arg : Any ,
25+ ) -> Union [int , float64 , float ]:
1626 index , x = arg
1727 return child_functions [index ](x )
1828
@@ -68,7 +78,19 @@ class BalancingLearner(BaseLearner):
6878 behave in an undefined way. Change the `strategy` in that case.
6979 """
7080
71- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
81+ def __init__ (
82+ self ,
83+ learners : Union [
84+ List [SequenceLearner ],
85+ List [AverageLearner ],
86+ List [Learner2D ],
87+ List [Learner1D ],
88+ List [LearnerND ],
89+ ],
90+ * ,
91+ cdims = None ,
92+ strategy = "loss_improvements"
93+ ) -> None :
7294 self .learners = learners
7395
7496 # Naively we would make 'function' a method, but this causes problems
@@ -89,21 +111,21 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
89111 self .strategy = strategy
90112
91113 @property
92- def data (self ):
114+ def data (self ) -> Dict [ Tuple [ int , int ], int ] :
93115 data = {}
94116 for i , l in enumerate (self .learners ):
95117 data .update ({(i , p ): v for p , v in l .data .items ()})
96118 return data
97119
98120 @property
99- def pending_points (self ):
121+ def pending_points (self ) -> Set [ Tuple [ int , int ]] :
100122 pending_points = set ()
101123 for i , l in enumerate (self .learners ):
102124 pending_points .update ({(i , p ) for p in l .pending_points })
103125 return pending_points
104126
105127 @property
106- def npoints (self ):
128+ def npoints (self ) -> int :
107129 return sum (l .npoints for l in self .learners )
108130
109131 @property
@@ -135,7 +157,7 @@ def strategy(self, strategy):
135157 ' strategy="npoints", or strategy="cycle" is implemented.'
136158 )
137159
138- def _ask_and_tell_based_on_loss_improvements (self , n ) :
160+ def _ask_and_tell_based_on_loss_improvements (self , n : int ) -> Any :
139161 selected = [] # tuples ((learner_index, point), loss_improvement)
140162 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
141163 for _ in range (n ):
@@ -158,7 +180,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
158180 points , loss_improvements = map (list , zip (* selected ))
159181 return points , loss_improvements
160182
161- def _ask_and_tell_based_on_loss (self , n ):
183+ def _ask_and_tell_based_on_loss (
184+ self , n : int
185+ ) -> Union [
186+ Tuple [List [Tuple [int , float ]], List [float64 ]],
187+ Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
188+ Tuple [List [Tuple [int , int ]], List [float ]],
189+ ]:
162190 selected = [] # tuples ((learner_index, point), loss_improvement)
163191 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
164192 for _ in range (n ):
@@ -179,7 +207,13 @@ def _ask_and_tell_based_on_loss(self, n):
179207 points , loss_improvements = map (list , zip (* selected ))
180208 return points , loss_improvements
181209
182- def _ask_and_tell_based_on_npoints (self , n ):
210+ def _ask_and_tell_based_on_npoints (
211+ self , n : int
212+ ) -> Union [
213+ Tuple [List [Union [Tuple [int64 , int ], Tuple [int64 , float ]]], List [float ]],
214+ Tuple [List [Tuple [int64 , float ]], List [float64 ]],
215+ Tuple [List [Tuple [int64 , int ]], List [float ]],
216+ ]:
183217 selected = [] # tuples ((learner_index, point), loss_improvement)
184218 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
185219 for _ in range (n ):
@@ -195,7 +229,13 @@ def _ask_and_tell_based_on_npoints(self, n):
195229 points , loss_improvements = map (list , zip (* selected ))
196230 return points , loss_improvements
197231
198- def _ask_and_tell_based_on_cycle (self , n ):
232+ def _ask_and_tell_based_on_cycle (
233+ self , n : int
234+ ) -> Union [
235+ Tuple [List [Tuple [int , float ]], List [float64 ]],
236+ Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
237+ Tuple [List [Tuple [int , int ]], List [float ]],
238+ ]:
199239 points , loss_improvements = [], []
200240 for _ in range (n ):
201241 index = next (self ._cycle )
@@ -206,7 +246,7 @@ def _ask_and_tell_based_on_cycle(self, n):
206246
207247 return points , loss_improvements
208248
209- def ask (self , n , tell_pending = True ):
249+ def ask (self , n : int , tell_pending : bool = True ) -> Any :
210250 """Chose points for learners."""
211251 if n == 0 :
212252 return [], []
@@ -217,20 +257,24 @@ def ask(self, n, tell_pending=True):
217257 else :
218258 return self ._ask_and_tell (n )
219259
220- def tell (self , x , y ):
260+ def tell (
261+ self , x : Any , y : Union [int , float64 , float , Tuple [int , int ], Tuple [int64 , int ]]
262+ ) -> None :
221263 index , x = x
222264 self ._ask_cache .pop (index , None )
223265 self ._loss .pop (index , None )
224266 self ._pending_loss .pop (index , None )
225267 self .learners [index ].tell (x , y )
226268
227- def tell_pending (self , x ) :
269+ def tell_pending (self , x : Any ) -> None :
228270 index , x = x
229271 self ._ask_cache .pop (index , None )
230272 self ._loss .pop (index , None )
231273 self .learners [index ].tell_pending (x )
232274
233- def _losses (self , real = True ):
275+ def _losses (
276+ self , real : bool = True
277+ ) -> Union [List [float ], List [float64 ], List [Union [float , float64 ]]]:
234278 losses = []
235279 loss_dict = self ._loss if real else self ._pending_loss
236280
@@ -242,7 +286,7 @@ def _losses(self, real=True):
242286 return losses
243287
244288 @cache_latest
245- def loss (self , real = True ):
289+ def loss (self , real : bool = True ) -> Union [ float64 , float ] :
246290 losses = self ._losses (real )
247291 return max (losses )
248292
@@ -372,7 +416,7 @@ def from_product(cls, f, learner_type, learner_kwargs, combos):
372416 learners .append (learner )
373417 return cls (learners , cdims = arguments )
374418
375- def save (self , fname , compress = True ):
419+ def save (self , fname : Callable , compress : bool = True ) -> None :
376420 """Save the data of the child learners into pickle files
377421 in a directory.
378422
@@ -410,7 +454,7 @@ def save(self, fname, compress=True):
410454 for l in self .learners :
411455 l .save (fname (l ), compress = compress )
412456
413- def load (self , fname , compress = True ):
457+ def load (self , fname : Callable , compress : bool = True ) -> None :
414458 """Load the data of the child learners from pickle files
415459 in a directory.
416460
0 commit comments