1414import warnings
1515from contextlib import suppress
1616from datetime import datetime , timedelta
17- from typing import Any , Callable
17+ from typing import Any , Callable , Union
1818
1919import loky
2020
2121from adaptive import BalancingLearner , BaseLearner , IntegratorLearner , SequenceLearner
2222from adaptive .notebook_integration import in_ipynb , live_info , live_plot
2323
24+ try :
25+ from typing import TypeAlias
26+ except ModuleNotFoundError :
27+ # Python <3.10
28+ from typing_extensions import TypeAlias
29+
2430try :
2531 import ipyparallel
2632
6066 # and https://github.com/python-adaptive/adaptive/issues/301
6167 _default_executor = loky .get_reusable_executor
6268
69+ GoalTypes : TypeAlias = Union [
70+ Callable [[BaseLearner ], bool ], int , float , datetime , timedelta , None
71+ ]
72+
6373
6474class BaseRunner (metaclass = abc .ABCMeta ):
6575 r"""Base class for runners that use `concurrent.futures.Executors`.
@@ -120,8 +130,10 @@ class BaseRunner(metaclass=abc.ABCMeta):
120130 def __init__ (
121131 self ,
122132 learner ,
123- goal ,
124133 * ,
134+ goal : GoalTypes = None ,
135+ loss_goal : float | None = None ,
136+ npoints_goal : int | None = None ,
125137 executor = None ,
126138 ntasks = None ,
127139 log = False ,
@@ -132,7 +144,7 @@ def __init__(
132144 ):
133145
134146 self .executor = _ensure_executor (executor )
135- self .goal = auto_goal ( goal , learner , allow_running_forever )
147+ self .goal = _goal ( learner , goal , loss_goal , npoints_goal , allow_running_forever )
136148
137149 self ._max_tasks = ntasks
138150
@@ -376,8 +388,10 @@ class BlockingRunner(BaseRunner):
376388 def __init__ (
377389 self ,
378390 learner ,
379- goal ,
380391 * ,
392+ goal : GoalTypes = None ,
393+ loss_goal : float | None = None ,
394+ npoints_goal : int | None = None ,
381395 executor = None ,
382396 ntasks = None ,
383397 log = False ,
@@ -389,7 +403,9 @@ def __init__(
389403 raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
390404 super ().__init__ (
391405 learner ,
392- goal ,
406+ goal = goal ,
407+ loss_goal = loss_goal ,
408+ npoints_goal = npoints_goal ,
393409 executor = executor ,
394410 ntasks = ntasks ,
395411 log = log ,
@@ -508,8 +524,10 @@ class AsyncRunner(BaseRunner):
508524 def __init__ (
509525 self ,
510526 learner ,
511- goal = None ,
512527 * ,
528+ goal : GoalTypes = None ,
529+ loss_goal : float | None = None ,
530+ npoints_goal : int | None = None ,
513531 executor = None ,
514532 ntasks = None ,
515533 log = False ,
@@ -537,7 +555,9 @@ def __init__(
537555
538556 super ().__init__ (
539557 learner ,
540- goal ,
558+ goal = goal ,
559+ loss_goal = loss_goal ,
560+ npoints_goal = npoints_goal ,
541561 executor = executor ,
542562 ntasks = ntasks ,
543563 log = log ,
@@ -717,7 +737,13 @@ async def _saver():
717737Runner = AsyncRunner
718738
719739
720- def simple (learner , goal ):
740+ def simple (
741+ learner ,
742+ * ,
743+ goal : GoalTypes = None ,
744+ loss_goal : float | None = None ,
745+ npoints_goal : int | None = None ,
746+ ):
721747 """Run the learner until the goal is reached.
722748
723749 Requests a single point from the learner, evaluates
@@ -736,7 +762,7 @@ def simple(learner, goal):
736762 The end condition for the calculation. This function must take the
737763 learner as its sole argument, and return True if we should stop.
738764 """
739- goal = auto_goal ( goal , learner )
765+ goal = _goal ( learner , goal , loss_goal , npoints_goal , allow_running_forever = False )
740766 while not goal (learner ):
741767 xs , _ = learner .ask (1 )
742768 for x in xs :
@@ -871,14 +897,13 @@ def __call__(self, _):
871897 if self .start_time is None :
872898 self .start_time = datetime .now ()
873899 return datetime .now () - self .start_time > self .dt
874- elif isinstance (self .dt , datetime ):
900+ if isinstance (self .dt , datetime ):
875901 return datetime .now () > self .dt
876- else :
877- raise TypeError (f"`dt={ self .dt } ` is not a datetime or timedelta." )
902+ raise TypeError (f"`dt={ self .dt } ` is not a datetime or timedelta." )
878903
879904
880905def auto_goal (
881- goal : Callable [[ BaseLearner ], bool ] | int | float | datetime | timedelta | None ,
906+ goal : GoalTypes ,
882907 learner : BaseLearner ,
883908 allow_running_forever : bool = True ,
884909):
@@ -935,12 +960,28 @@ def auto_goal(
935960 return SequenceLearner .done
936961 if isinstance (learner , IntegratorLearner ):
937962 return IntegratorLearner .done
938- warnings .warn ("Goal is None which means the learners continue forever!" )
939- if allow_running_forever :
940- return lambda _ : False
941- else :
963+ if not allow_running_forever :
942964 raise ValueError (
943965 "Goal is None which means the learners"
944966 " continue forever and this is not allowed."
945967 )
968+ warnings .warn ("Goal is None which means the learners continue forever!" )
969+ return lambda _ : False
946970 raise ValueError ("Cannot determine goal from {goal}." )
971+
972+
973+ def _goal (
974+ learner : BaseLearner ,
975+ goal : GoalTypes ,
976+ loss_goal : float | None ,
977+ npoints_goal : int | None ,
978+ allow_running_forever : bool ,
979+ ):
980+ # goal, loss_goal, npoints_goal are mutually exclusive, only one can be not None
981+ if goal is not None and (loss_goal is not None or npoints_goal is not None ):
982+ raise ValueError ("Either goal, loss_goal, or npoints_goal can be specified." )
983+ if loss_goal is not None :
984+ goal = float (loss_goal )
985+ if npoints_goal is not None :
986+ goal = int (npoints_goal )
987+ return auto_goal (goal , learner , allow_running_forever )
0 commit comments