|
9 | 9 | import numpy as np |
10 | 10 | from numpy import float64, int64 |
11 | 11 |
|
12 | | -from adaptive.learner.average_learner import AverageLearner |
13 | 12 | from adaptive.learner.base_learner import BaseLearner |
14 | | -from adaptive.learner.learner1D import Learner1D |
15 | | -from adaptive.learner.learner2D import Learner2D |
16 | | -from adaptive.learner.learnerND import LearnerND |
17 | | -from adaptive.learner.sequence_learner import SequenceLearner, _IgnoreFirstArgument |
18 | 13 | from adaptive.notebook_integration import ensure_holoviews |
19 | 14 | from adaptive.utils import cache_latest, named_product, restore |
20 | 15 |
|
21 | 16 |
|
22 | | -def dispatch( |
23 | | - child_functions: Union[List[Callable], List[partial], List[_IgnoreFirstArgument]], |
24 | | - arg: Any, |
25 | | -) -> Union[int, float64, float]: |
| 17 | +def dispatch(child_functions: List[Callable], arg: Any,) -> Union[int, float64, float]: |
26 | 18 | index, x = arg |
27 | 19 | return child_functions[index](x) |
28 | 20 |
|
@@ -79,17 +71,7 @@ class BalancingLearner(BaseLearner): |
79 | 71 | """ |
80 | 72 |
|
81 | 73 | def __init__( |
82 | | - self, |
83 | | - learners: Union[ |
84 | | - List[SequenceLearner], |
85 | | - List[AverageLearner], |
86 | | - List[Learner2D], |
87 | | - List[Learner1D], |
88 | | - List[LearnerND], |
89 | | - ], |
90 | | - *, |
91 | | - cdims=None, |
92 | | - strategy="loss_improvements" |
| 74 | + self, learners: List[BaseLearner], *, cdims=None, strategy="loss_improvements" |
93 | 75 | ) -> None: |
94 | 76 | self.learners = learners |
95 | 77 |
|
@@ -246,7 +228,7 @@ def _ask_and_tell_based_on_cycle( |
246 | 228 |
|
247 | 229 | return points, loss_improvements |
248 | 230 |
|
249 | | - def ask(self, n: int, tell_pending: bool = True) -> Any: |
| 231 | + def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]]: |
250 | 232 | """Chose points for learners.""" |
251 | 233 | if n == 0: |
252 | 234 | return [], [] |
@@ -369,7 +351,9 @@ def remove_unfinished(self): |
369 | 351 | learner.remove_unfinished() |
370 | 352 |
|
371 | 353 | @classmethod |
372 | | - def from_product(cls, f, learner_type, learner_kwargs, combos): |
| 354 | + def from_product( |
| 355 | + cls, f, learner_type, learner_kwargs, combos |
| 356 | + ) -> "BalancingLearner": |
373 | 357 | """Create a `BalancingLearner` with learners of all combinations of |
374 | 358 | named variables’ values. The `cdims` will be set correctly, so calling |
375 | 359 | `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels. |
|
0 commit comments