Skip to content

Commit 4df73ca

Browse files
authored
Merge pull request #401 from wwoods/master2streamz
Test showing synchronous partition(timeout) + dask scatter is broken
2 parents b3c46eb + eed557c commit 4df73ca

6 files changed

Lines changed: 92 additions & 24 deletions

File tree

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,3 +39,8 @@ max-line-length = 120
3939

4040
[bdist_wheel]
4141
universal=1
42+
43+
[tool:pytest]
44+
markers:
45+
network: Test requires an internet connection
46+
slow: Skipped unless --runslow passed

streamz/core.py

Lines changed: 31 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,11 @@
2121
except ImportError:
2222
PollIOLoop = None # dropped in tornado 6.0
2323

24+
try:
25+
from distributed.client import default_client as _dask_default_client
26+
except ImportError: # pragma: no cover
27+
_dask_default_client = None
28+
2429
from collections.abc import Iterable
2530

2631
from threading import get_ident as get_thread_identity
@@ -42,6 +47,15 @@ def get_io_loop(asynchronous=None):
4247
if asynchronous:
4348
return IOLoop.current()
4449

50+
if _dask_default_client is not None:
51+
try:
52+
client = _dask_default_client()
53+
except ValueError:
54+
# No dask client found; continue
55+
pass
56+
else:
57+
return client.loop
58+
4559
if not _io_loops:
4660
loop = IOLoop()
4761
thread = threading.Thread(target=loop.start)
@@ -163,6 +177,7 @@ class Stream(object):
163177

164178
def __init__(self, upstream=None, upstreams=None, stream_name=None,
165179
loop=None, asynchronous=None, ensure_io_loop=False):
180+
self.name = stream_name
166181
self.downstreams = OrderedWeakrefSet()
167182
if upstreams is not None:
168183
self.upstreams = list(upstreams)
@@ -182,8 +197,6 @@ def __init__(self, upstream=None, upstreams=None, stream_name=None,
182197
if upstream:
183198
upstream.downstreams.add(self)
184199

185-
self.name = stream_name
186-
187200
def _set_loop(self, loop):
188201
self.loop = None
189202
if loop is not None:
@@ -343,8 +356,6 @@ def __str__(self):
343356
s = str(at)
344357
elif hasattr(at, '__name__'):
345358
s = getattr(self, m).__name__
346-
elif hasattr(at.__class__, '__name__'):
347-
s = getattr(self, m).__class__.__name__
348359
else:
349360
s = None
350361
if s:
@@ -994,7 +1005,8 @@ def __init__(self, upstream, n, timeout=None, key=None, **kwargs):
9941005
self._buffer = defaultdict(lambda: [])
9951006
self._metadata_buffer = defaultdict(lambda: [])
9961007
self._callbacks = {}
997-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1008+
kwargs["ensure_io_loop"] = True
1009+
Stream.__init__(self, upstream, **kwargs)
9981010

9991011
def _get_key(self, x):
10001012
if self._key is None:
@@ -1206,7 +1218,8 @@ def __init__(self, upstream, interval, **kwargs):
12061218
self.metadata_buffer = []
12071219
self.last = gen.moment
12081220

1209-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1221+
kwargs["ensure_io_loop"] = True
1222+
Stream.__init__(self, upstream, **kwargs)
12101223

12111224
self.loop.add_callback(self.cb)
12121225

@@ -1308,7 +1321,8 @@ def __init__(
13081321
self._buffer = {}
13091322
self._metadata_buffer = {}
13101323
self.last = gen.moment
1311-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1324+
kwargs["ensure_io_loop"] = True
1325+
Stream.__init__(self, upstream, **kwargs)
13121326
self.loop.add_callback(self.cb)
13131327

13141328
def _get_key(self, x):
@@ -1355,7 +1369,8 @@ def __init__(self, upstream, interval, **kwargs):
13551369
self.interval = convert_interval(interval)
13561370
self.queue = Queue()
13571371

1358-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1372+
kwargs["ensure_io_loop"] = True
1373+
Stream.__init__(self, upstream,**kwargs)
13591374

13601375
self.loop.add_callback(self.cb)
13611376

@@ -1393,7 +1408,8 @@ def __init__(self, upstream, interval, **kwargs):
13931408
self.interval = convert_interval(interval)
13941409
self.next = 0
13951410

1396-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1411+
kwargs["ensure_io_loop"] = True
1412+
Stream.__init__(self, upstream, **kwargs)
13971413

13981414
@gen.coroutine
13991415
def update(self, x, who=None, metadata=None):
@@ -1418,7 +1434,8 @@ class buffer(Stream):
14181434
def __init__(self, upstream, n, **kwargs):
14191435
self.queue = Queue(maxsize=n)
14201436

1421-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1437+
kwargs["ensure_io_loop"] = True
1438+
Stream.__init__(self, upstream, **kwargs)
14221439

14231440
self.loop.add_callback(self.cb)
14241441

@@ -1862,7 +1879,8 @@ def __init__(self, upstream, **kwargs):
18621879
self.next = []
18631880
self.next_metadata = None
18641881

1865-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1882+
kwargs["ensure_io_loop"] = True
1883+
Stream.__init__(self, upstream, **kwargs)
18661884

18671885
self.loop.add_callback(self.cb)
18681886

@@ -1918,7 +1936,8 @@ def __init__(self, upstream, topic, producer_config, **kwargs):
19181936
self.topic = topic
19191937
self.producer = ck.Producer(producer_config)
19201938

1921-
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
1939+
kwargs["ensure_io_loop"] = True
1940+
Stream.__init__(self, upstream, **kwargs)
19221941
self.stopped = False
19231942
self.polltime = 0.2
19241943
self.loop.add_callback(self.poll)

streamz/dask.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,8 @@ class DaskStream(Stream):
3838
dask.distributed.Client
3939
"""
4040
def __init__(self, *args, **kwargs):
41-
if 'loop' not in kwargs:
42-
kwargs['loop'] = default_client().loop
43-
super(DaskStream, self).__init__(*args, **kwargs)
41+
kwargs["ensure_io_loop"] = True
42+
super().__init__(*args, **kwargs)
4443

4544

4645
@DaskStream.register_api()

streamz/dataframe/tests/test_dataframes.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from tornado import gen
1212

1313
from streamz import Stream
14-
from streamz.utils_test import gen_test
14+
from streamz.utils_test import gen_test, wait_for
1515
from streamz.dataframe import DataFrame, Series, DataFrames, Aggregation
1616
import streamz.dataframe as sd
1717
from streamz.dask import DaskStream
@@ -231,6 +231,8 @@ def test_index(stream):
231231
a.emit(df)
232232
a.emit(df)
233233

234+
wait_for(lambda: len(L) > 1, timeout=2, period=0.05)
235+
234236
assert_eq(L[0], df.index + 5)
235237
assert_eq(L[1], df.index + 5)
236238

streamz/tests/test_core.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -403,10 +403,11 @@ def test_timed_window():
403403

404404
@gen_test()
405405
def test_timed_window_ref_counts():
406-
source = Stream()
406+
source = Stream(asynchronous=True)
407407
_ = source.timed_window(0.01)
408408

409409
ref1 = RefCounter()
410+
assert str(ref1) == "<RefCounter count=0>"
410411
source.emit(1, metadata=[{'ref': ref1}])
411412
assert ref1.count == 1
412413
yield gen.sleep(0.05)
@@ -417,6 +418,13 @@ def test_timed_window_ref_counts():
417418
assert ref2.count == 1
418419

419420

421+
def test_mixed_async():
422+
s1 = Stream(asynchronous=False)
423+
with pytest.raises(ValueError):
424+
Stream(asynchronous=True, upstream=s1)
425+
426+
427+
420428
@gen_test()
421429
def test_timed_window_metadata():
422430
source = Stream()

streamz/tests/test_dask.py

Lines changed: 42 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from tornado import gen
99

1010
from streamz.dask import scatter
11-
from streamz import Stream
11+
from streamz import RefCounter, Stream
1212

1313
from distributed import Future, Client
1414
from distributed.utils import sync
@@ -51,6 +51,46 @@ def add_to_dict(d):
5151
assert item["i"] == i
5252

5353

54+
@gen_cluster(client=True)
55+
def test_partition_then_scatter_async(c, s, a, b):
56+
# Ensure partition w/ timeout before scatter works correctly for
57+
# asynchronous
58+
start = time.monotonic()
59+
source = Stream(asynchronous=True)
60+
61+
L = source.partition(2, timeout=.1).scatter().map(
62+
lambda x: [xx+1 for xx in x]).buffer(2).gather().flatten().sink_to_list()
63+
64+
rc = RefCounter(loop=source.loop)
65+
for i in range(3):
66+
yield source.emit(i, metadata=[{'ref': rc}])
67+
68+
while rc.count != 0 and time.monotonic() - start < 1.:
69+
yield gen.sleep(1e-2)
70+
71+
assert L == [1, 2, 3]
72+
73+
74+
def test_partition_then_scatter_sync(loop):
75+
# Ensure partition w/ timeout before scatter works correctly for synchronous
76+
with cluster() as (s, [a, b]):
77+
with Client(s['address'], loop=loop) as client: # noqa: F841
78+
start = time.monotonic()
79+
source = Stream()
80+
L = source.partition(2, timeout=.1).scatter().map(
81+
lambda x: [xx+1 for xx in x]).gather().flatten().sink_to_list()
82+
assert source.loop is client.loop
83+
84+
rc = RefCounter()
85+
for i in range(3):
86+
source.emit(i, metadata=[{'ref': rc}])
87+
88+
while rc.count != 0 and time.monotonic() - start < 2.:
89+
time.sleep(1e-2)
90+
91+
assert L == [1, 2, 3]
92+
93+
5494
@gen_cluster(client=True)
5595
def test_non_unique_emit(c, s, a, b):
5696
"""Regression for https://github.com/python-streamz/streams/issues/397
@@ -123,7 +163,6 @@ def test_accumulate(c, s, a, b):
123163
assert L[-1][1] == 3
124164

125165

126-
@pytest.mark.slow
127166
def test_sync(loop): # noqa: F811
128167
with cluster() as (s, [a, b]):
129168
with Client(s['address'], loop=loop) as client: # noqa: F841
@@ -140,7 +179,6 @@ def f():
140179
assert L == list(map(inc, range(10)))
141180

142181

143-
@pytest.mark.slow
144182
def test_sync_2(loop): # noqa: F811
145183
with cluster() as (s, [a, b]):
146184
with Client(s['address'], loop=loop): # noqa: F841
@@ -154,8 +192,7 @@ def test_sync_2(loop): # noqa: F811
154192
assert L == list(map(inc, range(10)))
155193

156194

157-
@pytest.mark.slow
158-
@gen_cluster(client=True, ncores=[('127.0.0.1', 1)] * 2)
195+
@gen_cluster(client=True, nthreads=[('127.0.0.1', 1)] * 2)
159196
def test_buffer(c, s, a, b):
160197
source = Stream(asynchronous=True)
161198
L = source.scatter().map(slowinc, delay=0.5).buffer(5).gather().sink_to_list()
@@ -181,7 +218,6 @@ def test_buffer(c, s, a, b):
181218
assert source.loop == c.loop
182219

183220

184-
@pytest.mark.slow
185221
def test_buffer_sync(loop): # noqa: F811
186222
with cluster() as (s, [a, b]):
187223
with Client(s['address'], loop=loop) as c: # noqa: F841
@@ -206,7 +242,6 @@ def test_buffer_sync(loop): # noqa: F811
206242

207243

208244
@pytest.mark.xfail(reason='')
209-
@pytest.mark.slow
210245
def test_stream_shares_client_loop(loop): # noqa: F811
211246
with cluster() as (s, [a, b]):
212247
with Client(s['address'], loop=loop) as client: # noqa: F841

0 commit comments

Comments (0)