Skip to content

Commit f060d9f

Browse files
author
wwoods
committed
Pick up dask's Client.loop when possible; fixes partition(timeout)
Originally, `Stream().partition(2, timeout=1.).scatter()` would result in an error about two different event loops running. This is because the `partition` operator required streamz to allocate its own IOLoop, while the subsequent `scatter` used the `dask.distributed.Client` event loop. This patch introduces a guarded import in core that checks whether a dask client is running and, if so, uses that client's loop instead of starting a new one. Also adds regression tests for both the asynchronous and synchronous behavior.
1 parent b82ca42 commit f060d9f

3 files changed

Lines changed: 56 additions & 5 deletions

File tree

streamz/core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
except ImportError:
2222
PollIOLoop = None # dropped in tornado 6.0
2323

24+
try:
25+
from distributed.client import default_client as _dask_default_client
26+
except ImportError:
27+
_dask_default_client = None
28+
2429
from collections.abc import Iterable
2530

2631
from .compatibility import get_thread_identity
@@ -42,6 +47,15 @@ def get_io_loop(asynchronous=None):
4247
if asynchronous:
4348
return IOLoop.current()
4449

50+
if _dask_default_client is not None:
51+
try:
52+
client = _dask_default_client()
53+
except ValueError:
54+
# No dask client found; continue
55+
pass
56+
else:
57+
return client.loop
58+
4559
if not _io_loops:
4660
loop = IOLoop()
4761
thread = threading.Thread(target=loop.start)

streamz/dask.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ class DaskStream(Stream):
3838
--------
3939
dask.distributed.Client
4040
"""
41-
def __init__(self, *args, **kwargs):
42-
if 'loop' not in kwargs:
43-
kwargs['loop'] = default_client().loop
44-
super(DaskStream, self).__init__(*args, **kwargs)
4541

4642

4743
@DaskStream.register_api()

streamz/tests/test_dask.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tornado import gen
88

99
from streamz.dask import scatter
10-
from streamz import Stream
10+
from streamz import RefCounter, Stream
1111

1212
from distributed import Future, Client
1313
from distributed.utils import sync
@@ -50,6 +50,47 @@ def add_to_dict(d):
5050
assert item["i"] == i
5151

5252

53+
@gen_cluster(client=True)
54+
def test_partition_then_scatter_async(c, s, a, b):
55+
# Ensure partition w/ timeout before scatter works correctly for
56+
# asynchronous
57+
start = time.monotonic()
58+
source = Stream(asynchronous=True)
59+
60+
L = source.partition(2, timeout=.1).scatter().map(
61+
lambda x: [xx+1 for xx in x]).buffer(2).gather().flatten().sink_to_list()
62+
63+
rc = RefCounter(loop=source.loop)
64+
for i in range(3):
65+
yield source.emit(i, metadata=[{'ref': rc}])
66+
67+
while rc.count != 0 and time.monotonic() - start < 1.:
68+
yield gen.sleep(1e-2)
69+
70+
assert L == [1, 2, 3]
71+
72+
73+
@pytest.mark.slow
74+
def test_partition_then_scatter_sync(loop):
75+
# Ensure partition w/ timeout before scatter works correctly for synchronous
76+
with cluster() as (s, [a, b]):
77+
with Client(s['address'], loop=loop) as client: # noqa: F841
78+
start = time.monotonic()
79+
source = Stream()
80+
L = source.partition(2, timeout=.1).scatter().map(
81+
lambda x: [xx+1 for xx in x]).gather().flatten().sink_to_list()
82+
assert source.loop is client.loop
83+
84+
rc = RefCounter()
85+
for i in range(3):
86+
source.emit(i, metadata=[{'ref': rc}])
87+
88+
while rc.count != 0 and time.monotonic() - start < 2.:
89+
time.sleep(1e-2)
90+
91+
assert L == [1, 2, 3]
92+
93+
5394
@gen_cluster(client=True)
5495
def test_scan(c, s, a, b):
5596
source = Stream(asynchronous=True)

0 commit comments

Comments
 (0)