1313import traceback
1414import warnings
1515from contextlib import suppress
16+ from datetime import datetime , timedelta
1617from typing import TYPE_CHECKING , Any , Callable
1718
1819import loky
@@ -129,10 +130,11 @@ def __init__(
129130 shutdown_executor = False ,
130131 retries = 0 ,
131132 raise_if_retries_exceeded = True ,
133+ allow_running_forever = False ,
132134 ):
133135
134136 self .executor = _ensure_executor (executor )
135- self .goal = goal
137+ self .goal = auto_goal ( goal , learner , allow_running_forever )
136138
137139 self ._max_tasks = ntasks
138140
@@ -396,6 +398,7 @@ def __init__(
396398 shutdown_executor = shutdown_executor ,
397399 retries = retries ,
398400 raise_if_retries_exceeded = raise_if_retries_exceeded ,
401+ allow_running_forever = False ,
399402 )
400403 self ._run ()
401404
@@ -518,11 +521,6 @@ def __init__(
518521 raise_if_retries_exceeded = True ,
519522 ):
520523
521- if goal is None :
522-
523- def goal (_ ):
524- return False
525-
526524 if (
527525 executor is None
528526 and _default_executor is concurrent .ProcessPoolExecutor
@@ -548,6 +546,7 @@ def goal(_):
548546 shutdown_executor = shutdown_executor ,
549547 retries = retries ,
550548 raise_if_retries_exceeded = raise_if_retries_exceeded ,
549+ allow_running_forever = True ,
551550 )
552551 self .ioloop = ioloop or asyncio .get_event_loop ()
553552 self .task = None
@@ -861,3 +860,89 @@ def _get_ncores(ex):
861860 return ex ._pool .size # not public API!
862861 else :
863862 raise TypeError (f"Cannot get number of cores for { ex .__class__ } " )
863+
864+
865+ class _TimeGoal :
866+ def __init__ (self , dt : timedelta | datetime ):
867+ self .dt = dt
868+ self .start_time = None
869+
870+ def __call__ (self , _ ):
871+ if isinstance (self .dt , timedelta ):
872+ if self .start_time is None :
873+ self .start_time = datetime .now ()
874+ return datetime .now () - self .start_time > self .dt
875+ elif isinstance (self .dt , datetime ):
876+ return datetime .now () > self .dt
877+ else :
878+ raise TypeError (f"{ self .dt = } is not a datetime or timedelta." )
879+
880+
881+ def auto_goal (
882+ goal : Callable [[BaseLearner ], bool ] | int | float | datetime | timedelta | None ,
883+ learner : BaseLearner ,
884+ allow_running_forever : bool = True ,
885+ ):
886+ """Extract a goal from the learners.
887+
888+ Parameters
889+ ----------
890+ goal
891+ The goal to extract. Can be a callable, an integer, a float, a datetime,
892+ a timedelta or None.
893+ If it is a callable, it is returned as is.
894+ If it is an integer, the goal is reached after that many points have been
895+ returned.
896+ If it is a float, the goal is reached when the learner has reached a loss
897+ less than that.
898+ If it is a datetime, the goal is reached when the current time is after the
899+ datetime.
900+ If it is a timedelta, the goal is reached when the current time is after
901+ the start time plus that timedelta.
902+ If it is None, and
903+ - the learner type is `adaptive.SequenceLearner`, it continues until
904+ it no more points to add
905+ - the learner type is `adaptive.Integrator`, it continues until the
906+ error is less than the tolerance.
907+ - otherwise, it continues forever, unless `allow_running_forever` is
908+ False, in which case it raises a ValueError.
909+ learner
910+ Learner for which to determine the goal.
911+ allow_running_forever
912+ If True, and the goal is None and the learner is not a SequenceLearner,
913+ then a goal that never stops is returned, otherwise an exception is raised.
914+
915+ Returns
916+ -------
917+ Callable[[adaptive.BaseLearner], bool]
918+ """
919+ from adaptive import BalancingLearner , IntegratorLearner , SequenceLearner
920+
921+ if callable (goal ):
922+ return goal
923+ if isinstance (goal , float ):
924+ return lambda learner : learner .loss () <= goal
925+ if isinstance (learner , BalancingLearner ):
926+ # Note that the float loss goal is more efficiently implemented in the
927+ # BalancingLearner itself. That is why the previous if statement is
928+ # above this one.
929+ goals = [auto_goal (goal , l , allow_running_forever ) for l in learner .learners ]
930+ return lambda learner : all (goal (l ) for l , goal in zip (learner .learners , goals ))
931+ if isinstance (goal , int ):
932+ return lambda learner : learner .npoints >= goal
933+ if isinstance (goal , (timedelta , datetime )):
934+ return _TimeGoal (goal )
935+ if goal is None :
936+ if isinstance (learner , SequenceLearner ):
937+ return SequenceLearner .done
938+ if isinstance (learner , IntegratorLearner ):
939+ return IntegratorLearner .done
940+ warnings .warn ("Goal is None which means the learners continue forever!" )
941+ if allow_running_forever :
942+ return lambda _ : False
943+ else :
944+ raise ValueError (
945+ "Goal is None which means the learners"
946+ " continue forever and this is not allowed."
947+ )
948+ raise ValueError ("Cannot determine goal from {goal}." )
0 commit comments