|
26 | 26 | except ModuleNotFoundError: |
27 | 27 | with_distributed = False |
28 | 28 |
|
| 29 | +try: |
| 30 | + import mpi4py.futures |
| 31 | + with_mpi4py = True |
| 32 | +except ModuleNotFoundError: |
| 33 | + with_mpi4py = False |
| 34 | + |
29 | 35 | with suppress(ModuleNotFoundError): |
30 | 36 | import uvloop |
31 | 37 | asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) |
@@ -66,7 +72,7 @@ class BaseRunner(metaclass=abc.ABCMeta): |
66 | 72 | the learner as its sole argument, and return True when we should |
67 | 73 | stop requesting more points. |
68 | 74 | executor : `concurrent.futures.Executor`, `distributed.Client`,\ |
69 | | - or `ipyparallel.Client`, optional |
| 75 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
70 | 76 | The executor in which to evaluate the function to be learned. |
71 | 77 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor` |
72 | 78 | is used on Unix systems while on Windows a `distributed.Client` |
@@ -281,7 +287,7 @@ class BlockingRunner(BaseRunner): |
281 | 287 | the learner as its sole argument, and return True when we should |
282 | 288 | stop requesting more points. |
283 | 289 | executor : `concurrent.futures.Executor`, `distributed.Client`,\ |
284 | | - or `ipyparallel.Client`, optional |
| 290 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
285 | 291 | The executor in which to evaluate the function to be learned. |
286 | 292 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor` |
287 | 293 | is used on Unix systems while on Windows a `distributed.Client` |
@@ -386,7 +392,7 @@ class AsyncRunner(BaseRunner): |
386 | 392 | stop requesting more points. If not provided, the runner will run |
387 | 393 | forever, or until ``self.task.cancel()`` is called. |
388 | 394 | executor : `concurrent.futures.Executor`, `distributed.Client`,\ |
389 | | - or `ipyparallel.Client`, optional |
| 395 | + `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional |
390 | 396 | The executor in which to evaluate the function to be learned. |
391 | 397 | If not provided, a new `~concurrent.futures.ProcessPoolExecutor` |
392 | 398 | is used on Unix systems while on Windows a `distributed.Client` |
@@ -693,6 +699,9 @@ def _get_ncores(ex): |
693 | 699 | return 1 |
694 | 700 | elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor): |
695 | 701 | return sum(n for n in ex._client.ncores().values()) |
| 702 | + elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor): |
| 703 | + ex.bootup() # wait until all workers are up and running |
| 704 | + return ex._pool.size # not public API! |
696 | 705 | else: |
697 | 706 | raise TypeError('Cannot get number of cores for {}' |
698 | 707 | .format(ex.__class__)) |
0 commit comments