1+ from __future__ import annotations
2+
13import itertools
24import numbers
35from collections import defaultdict
46from collections .abc import Iterable
57from contextlib import suppress
68from functools import partial
79from operator import itemgetter
8- from typing import (
9- Any ,
10- Callable ,
11- Dict ,
12- List ,
13- Literal ,
14- Optional ,
15- Sequence ,
16- Set ,
17- Tuple ,
18- Union ,
19- )
10+ from typing import Any , Callable , Dict , Literal , Sequence , Tuple , Union
2011
2112import numpy as np
2213
2516from adaptive .utils import cache_latest , named_product , restore
2617
2718
28- def dispatch (child_functions : List [Callable ], arg : Any ) -> Union [ Any ] :
19+ def dispatch (child_functions : list [Callable ], arg : Any ) -> Any :
2920 index , x = arg
3021 return child_functions [index ](x )
3122
@@ -91,9 +82,9 @@ class BalancingLearner(BaseLearner):
9182
9283 def __init__ (
9384 self ,
94- learners : List [BaseLearner ],
85+ learners : list [BaseLearner ],
9586 * ,
96- cdims : Optional [ CDIMS_TYPE ] = None ,
87+ cdims : CDIMS_TYPE | None = None ,
9788 strategy : STRATEGY_TYPE = "loss_improvements" ,
9889 ) -> None :
9990 self .learners = learners
@@ -116,14 +107,14 @@ def __init__(
116107 self .strategy : STRATEGY_TYPE = strategy
117108
118109 @property
119- def data (self ) -> Dict [ Tuple [int , Any ], Any ]:
110+ def data (self ) -> dict [ tuple [int , Any ], Any ]:
120111 data = {}
121112 for i , l in enumerate (self .learners ):
122113 data .update ({(i , p ): v for p , v in l .data .items ()})
123114 return data
124115
125116 @property
126- def pending_points (self ) -> Set [ Tuple [int , Any ]]:
117+ def pending_points (self ) -> set [ tuple [int , Any ]]:
127118 pending_points = set ()
128119 for i , l in enumerate (self .learners ):
129120 pending_points .update ({(i , p ) for p in l .pending_points })
@@ -173,7 +164,7 @@ def strategy(self, strategy: STRATEGY_TYPE) -> None:
173164
174165 def _ask_and_tell_based_on_loss_improvements (
175166 self , n : int
176- ) -> Tuple [ List [ Tuple [int , Any ]], List [float ]]:
167+ ) -> tuple [ list [ tuple [int , Any ]], list [float ]]:
177168 selected = [] # tuples ((learner_index, point), loss_improvement)
178169 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
179170 for _ in range (n ):
@@ -198,7 +189,7 @@ def _ask_and_tell_based_on_loss_improvements(
198189
199190 def _ask_and_tell_based_on_loss (
200191 self , n : int
201- ) -> Tuple [ List [ Tuple [int , Any ]], List [float ]]:
192+ ) -> tuple [ list [ tuple [int , Any ]], list [float ]]:
202193 selected = [] # tuples ((learner_index, point), loss_improvement)
203194 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
204195 for _ in range (n ):
@@ -221,7 +212,7 @@ def _ask_and_tell_based_on_loss(
221212
222213 def _ask_and_tell_based_on_npoints (
223214 self , n : numbers .Integral
224- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
215+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
225216 selected = [] # tuples ((learner_index, point), loss_improvement)
226217 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
227218 for _ in range (n ):
@@ -239,7 +230,7 @@ def _ask_and_tell_based_on_npoints(
239230
240231 def _ask_and_tell_based_on_cycle (
241232 self , n : int
242- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
233+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
243234 points , loss_improvements = [], []
244235 for _ in range (n ):
245236 index = next (self ._cycle )
@@ -252,7 +243,7 @@ def _ask_and_tell_based_on_cycle(
252243
253244 def ask (
254245 self , n : int , tell_pending : bool = True
255- ) -> Tuple [ List [ Tuple [numbers .Integral , Any ]], List [float ]]:
246+ ) -> tuple [ list [ tuple [numbers .Integral , Any ]], list [float ]]:
256247 """Chose points for learners."""
257248 if n == 0 :
258249 return [], []
@@ -263,20 +254,20 @@ def ask(
263254 else :
264255 return self ._ask_and_tell (n )
265256
266- def tell (self , x : Tuple [numbers .Integral , Any ], y : Any ) -> None :
257+ def tell (self , x : tuple [numbers .Integral , Any ], y : Any ) -> None :
267258 index , x = x
268259 self ._ask_cache .pop (index , None )
269260 self ._loss .pop (index , None )
270261 self ._pending_loss .pop (index , None )
271262 self .learners [index ].tell (x , y )
272263
273- def tell_pending (self , x : Tuple [numbers .Integral , Any ]) -> None :
264+ def tell_pending (self , x : tuple [numbers .Integral , Any ]) -> None :
274265 index , x = x
275266 self ._ask_cache .pop (index , None )
276267 self ._loss .pop (index , None )
277268 self .learners [index ].tell_pending (x )
278269
279- def _losses (self , real : bool = True ) -> List [float ]:
270+ def _losses (self , real : bool = True ) -> list [float ]:
280271 losses = []
281272 loss_dict = self ._loss if real else self ._pending_loss
282273
@@ -294,8 +285,8 @@ def loss(self, real: bool = True) -> float:
294285
295286 def plot (
296287 self ,
297- cdims : Optional [ CDIMS_TYPE ] = None ,
298- plotter : Optional [ Callable [[BaseLearner ], Any ]] = None ,
288+ cdims : CDIMS_TYPE | None = None ,
289+ plotter : Callable [[BaseLearner ], Any ] | None = None ,
299290 dynamic : bool = True ,
300291 ):
301292 """Returns a DynamicMap with sliders.
@@ -380,9 +371,9 @@ def from_product(
380371 cls ,
381372 f ,
382373 learner_type : BaseLearner ,
383- learner_kwargs : Dict [str , Any ],
384- combos : Dict [str , Sequence [Any ]],
385- ) -> " BalancingLearner" :
374+ learner_kwargs : dict [str , Any ],
375+ combos : dict [str , Sequence [Any ]],
376+ ) -> BalancingLearner :
386377 """Create a `BalancingLearner` with learners of all combinations of
387378 named variables’ values. The `cdims` will be set correctly, so calling
388379 `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -431,7 +422,7 @@ def from_product(
431422
432423 def save (
433424 self ,
434- fname : Union [ Callable [[BaseLearner ], str ], Sequence [str ] ],
425+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
435426 compress : bool = True ,
436427 ) -> None :
437428 """Save the data of the child learners into pickle files
@@ -473,7 +464,7 @@ def save(
473464
474465 def load (
475466 self ,
476- fname : Union [ Callable [[BaseLearner ], str ], Sequence [str ] ],
467+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
477468 compress : bool = True ,
478469 ) -> None :
479470 """Load the data of the child learners from pickle files
@@ -499,20 +490,20 @@ def load(
499490 for l in self .learners :
500491 l .load (fname (l ), compress = compress )
501492
502- def _get_data (self ) -> List [Any ]:
493+ def _get_data (self ) -> list [Any ]:
503494 return [l ._get_data () for l in self .learners ]
504495
505- def _set_data (self , data : List [Any ]):
496+ def _set_data (self , data : list [Any ]):
506497 for l , _data in zip (self .learners , data ):
507498 l ._set_data (_data )
508499
509- def __getstate__ (self ) -> Tuple [ List [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]:
500+ def __getstate__ (self ) -> tuple [ list [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]:
510501 return (
511502 self .learners ,
512503 self ._cdims_default ,
513504 self .strategy ,
514505 )
515506
516- def __setstate__ (self , state : Tuple [ List [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]):
507+ def __setstate__ (self , state : tuple [ list [BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ]):
517508 learners , cdims , strategy = state
518509 self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments