Skip to content

Commit 4e71119

Browse files
authored
Merge pull request #349 from nils-braun/bugfix/336-problematic-dicts
Wrap the items passed to `client.scatter` with a dict
2 parents caa6583 + 78132db commit 4e71119

2 files changed

Lines changed: 31 additions & 1 deletion

File tree

streamz/dask.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from tornado import gen
66

77
from dask.compatibility import apply
8+
from dask.base import tokenize
89
from distributed.client import default_client
910

1011
from .core import Stream
@@ -103,7 +104,14 @@ def update(self, x, who=None, metadata=None):
103104
client = default_client()
104105

105106
self._retain_refs(metadata)
106-
future = yield client.scatter(x, asynchronous=True)
107+
# We need to make sure that x is treated as it is by dask
108+
# However, client.scatter works internally different for
109+
# lists and dicts. So we always use a dict here to be sure
110+
# we know the format exactly. The key will be taken as the
111+
# dask identifier of the data.
112+
tokenized_x = f"{type(x).__name__}-{tokenize(x)}"
113+
future_as_dict = yield client.scatter({tokenized_x: x}, asynchronous=True)
114+
future = future_as_dict[tokenized_x]
107115
f = yield self._emit(future, metadata=metadata)
108116
self._release_refs(metadata)
109117

streamz/tests/test_dask.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ def test_map(c, s, a, b):
2828
assert all(isinstance(f, Future) for f in futures_L)
2929

3030

31+
@gen_cluster(client=True)
32+
def test_map_on_dict(c, s, a, b):
33+
# dask treats dicts differently, so we have to make sure
34+
# the user sees no difference in the streamz api.
35+
# Regression test against #336
36+
def add_to_dict(d):
37+
d["x"] = d["i"]
38+
return d
39+
40+
source = Stream(asynchronous=True)
41+
futures = source.scatter().map(add_to_dict)
42+
L = futures.gather().sink_to_list()
43+
44+
for i in range(5):
45+
yield source.emit({"i": i})
46+
47+
assert len(L) == 5
48+
for i, item in enumerate(sorted(L, key=lambda x: x["x"])):
49+
assert item["x"] == i
50+
assert item["i"] == i
51+
52+
3153
@gen_cluster(client=True)
3254
def test_scan(c, s, a, b):
3355
source = Stream(asynchronous=True)

0 commit comments

Comments
 (0)