88import time
99import traceback
1010import warnings
11+ from _asyncio import Future , Task
12+ from concurrent .futures .process import ProcessPoolExecutor
1113from contextlib import suppress
14+ from typing import Any , Callable , List , Optional , Set , Tuple , Union
1215
16+ from distributed .cfexecutor import ClientExecutor
17+ from distributed .client import Client
18+ from ipyparallel .client .asyncresult import AsyncResult
19+ from ipyparallel .client .view import ViewExecutor
20+ from numpy import float64
21+
22+ from adaptive .learner .learner1D import Learner1D
23+ from adaptive .learner .learner2D import Learner2D
24+ from adaptive .learner .learnerND import LearnerND
1325from adaptive .notebook_integration import in_ipynb , live_info , live_plot
1426
1527try :
@@ -121,16 +133,16 @@ class BaseRunner(metaclass=abc.ABCMeta):
121133
122134 def __init__ (
123135 self ,
124- learner ,
125- goal ,
136+ learner : Union [ Learner1D , Learner2D , LearnerND ] ,
137+ goal : Callable ,
126138 * ,
127139 executor = None ,
128140 ntasks = None ,
129141 log = False ,
130142 shutdown_executor = False ,
131143 retries = 0 ,
132144 raise_if_retries_exceeded = True ,
133- ):
145+ ) -> None :
134146
135147 self .executor = _ensure_executor (executor )
136148 self .goal = goal
@@ -157,7 +169,7 @@ def __init__(
157169 self .to_retry = {}
158170 self .tracebacks = {}
159171
160- def _get_max_tasks (self ):
172+ def _get_max_tasks (self ) -> int :
161173 return self ._max_tasks or _get_ncores (self .executor )
162174
163175 def _do_raise (self , e , x ):
@@ -169,10 +181,10 @@ def _do_raise(self, e, x):
169181 ) from e
170182
171183 @property
172- def do_log (self ):
184+ def do_log (self ) -> bool :
173185 return self .log is not None
174186
175- def _ask (self , n ) :
187+ def _ask (self , n : int ) -> Any :
176188 points = [
177189 p for p in self .to_retry .keys () if p not in self .pending_points .values ()
178190 ][:n ]
@@ -206,7 +218,9 @@ def overhead(self):
206218 t_total = self .elapsed_time ()
207219 return (1 - t_function / t_total ) * 100
208220
209- def _process_futures (self , done_futs ):
221+ def _process_futures (
222+ self , done_futs : Union [Set [Future ], Set [Future ], Set [AsyncResult ], Set [Task ]]
223+ ) -> None :
210224 for fut in done_futs :
211225 x = self .pending_points .pop (fut )
212226 try :
@@ -227,7 +241,9 @@ def _process_futures(self, done_futs):
227241 self .log .append (("tell" , x , y ))
228242 self .learner .tell (x , y )
229243
230- def _get_futures (self ):
244+ def _get_futures (
245+ self ,
246+ ) -> Union [List [Task ], List [Future ], List [Future ], List [AsyncResult ]]:
231247 # Launch tasks to replace the ones that completed
232248 # on the last iteration, making sure to fill workers
233249 # that have started since the last iteration.
@@ -248,7 +264,7 @@ def _get_futures(self):
248264 futures = list (self .pending_points .keys ())
249265 return futures
250266
251- def _remove_unfinished (self ):
267+ def _remove_unfinished (self ) -> List [ Future ] :
252268 # remove points with 'None' values from the learner
253269 self .learner .remove_unfinished ()
254270 # cancel any outstanding tasks
@@ -257,7 +273,7 @@ def _remove_unfinished(self):
257273 fut .cancel ()
258274 return remaining
259275
260- def _cleanup (self ):
276+ def _cleanup (self ) -> None :
261277 if self .shutdown_executor :
262278 # XXX: temporary set wait=True for Python 3.7
263279 # see https://github.com/python-adaptive/adaptive/issues/156
@@ -347,16 +363,16 @@ class BlockingRunner(BaseRunner):
347363
348364 def __init__ (
349365 self ,
350- learner ,
351- goal ,
366+ learner : Union [ LearnerND , Learner2D , Learner1D ] ,
367+ goal : Callable ,
352368 * ,
353369 executor = None ,
354370 ntasks = None ,
355371 log = False ,
356372 shutdown_executor = False ,
357373 retries = 0 ,
358374 raise_if_retries_exceeded = True ,
359- ):
375+ ) -> None :
360376 if inspect .iscoroutinefunction (learner .function ):
361377 raise ValueError (
362378 "Coroutine functions can only be used " "with 'AsyncRunner'."
@@ -373,10 +389,12 @@ def __init__(
373389 )
374390 self ._run ()
375391
376- def _submit (self , x ):
392+ def _submit (
393+ self , x : Union [Tuple [int , int ], int , Tuple [float64 , float64 ], float ]
394+ ) -> Union [Future , AsyncResult ]:
377395 return self .executor .submit (self .learner .function , x )
378396
379- def _run (self ):
397+ def _run (self ) -> None :
380398 first_completed = concurrent .FIRST_COMPLETED
381399
382400 if self ._get_max_tasks () < 1 :
@@ -476,8 +494,8 @@ class AsyncRunner(BaseRunner):
476494
477495 def __init__ (
478496 self ,
479- learner ,
480- goal = None ,
497+ learner : Union [ Learner1D , Learner2D ] ,
498+ goal : Optional [ Callable ] = None ,
481499 * ,
482500 executor = None ,
483501 ntasks = None ,
@@ -486,7 +504,7 @@ def __init__(
486504 ioloop = None ,
487505 retries = 0 ,
488506 raise_if_retries_exceeded = True ,
489- ):
507+ ) -> None :
490508
491509 if goal is None :
492510
@@ -539,7 +557,9 @@ def goal(_):
539557 "'adaptive.notebook_extension()'"
540558 )
541559
542- def _submit (self , x ):
560+ def _submit (
561+ self , x : Union [Tuple [int , int ], int , Tuple [float64 , float64 ], float ]
562+ ) -> Union [Task , Future ]:
543563 ioloop = self .ioloop
544564 if inspect .iscoroutinefunction (self .learner .function ):
545565 return ioloop .create_task (self .learner .function (x ))
@@ -604,7 +624,7 @@ def live_info(self, *, update_interval=0.1):
604624 """
605625 return live_info (self , update_interval = update_interval )
606626
607- async def _run (self ):
627+ async def _run (self ) -> None :
608628 first_completed = asyncio .FIRST_COMPLETED
609629
610630 if self ._get_max_tasks () < 1 :
@@ -668,7 +688,7 @@ async def _saver(save_kwargs=save_kwargs, interval=interval):
668688Runner = AsyncRunner
669689
670690
671- def simple (learner , goal ) :
691+ def simple (learner : Any , goal : Callable ) -> None :
672692 """Run the learner until the goal is reached.
673693
674694 Requests a single point from the learner, evaluates
@@ -694,7 +714,16 @@ def simple(learner, goal):
694714 learner .tell (x , y )
695715
696716
697- def replay_log (learner , log ):
717+ def replay_log (
718+ learner : LearnerND ,
719+ log : List [
720+ Union [
721+ Tuple [str , int ],
722+ Tuple [str , Tuple [int , int , int ], float ],
723+ Tuple [str , Tuple [float , float , float ], float ],
724+ ]
725+ ],
726+ ) -> None :
698727 """Apply a sequence of method calls to a learner.
699728
700729 This is useful for debugging runners.
@@ -713,7 +742,7 @@ def replay_log(learner, log):
713742# --- Useful runner goals
714743
715744
716- def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ):
745+ def stop_after (* , seconds = 0 , minutes = 0 , hours = 0 ) -> Callable :
717746 """Stop a runner after a specified time.
718747
719748 For example, to specify a runner that should stop after
@@ -756,7 +785,7 @@ class SequentialExecutor(concurrent.Executor):
756785 This executor is mainly for testing.
757786 """
758787
759- def submit (self , fn , * args , ** kwargs ):
788+ def submit (self , fn : Callable , * args , ** kwargs ) -> Future :
760789 fut = concurrent .Future ()
761790 try :
762791 fut .set_result (fn (* args , ** kwargs ))
@@ -771,7 +800,9 @@ def shutdown(self, wait=True):
771800 pass
772801
773802
774- def _ensure_executor (executor ):
803+ def _ensure_executor (
804+ executor : Optional [Union [Client , Client , ProcessPoolExecutor , SequentialExecutor ]]
805+ ) -> Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]:
775806 if executor is None :
776807 executor = _default_executor (** _default_executor_kwargs )
777808
@@ -788,7 +819,9 @@ def _ensure_executor(executor):
788819 )
789820
790821
791- def _get_ncores (ex ):
822+ def _get_ncores (
823+ ex : Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]
824+ ) -> int :
792825 """Return the maximum number of cores that an executor can use."""
793826 if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
794827 return len (ex .view )
0 commit comments