Skip to content

Commit 6ecd2c9

Browse files
committed
Merge branch 'master' into refactor_sinks
2 parents a0c4c04 + a6e9111 commit 6ecd2c9

3 files changed

Lines changed: 133 additions & 17 deletions

File tree

streamz/core.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import absolute_import, division, print_function
22

3-
from collections import deque
3+
from collections import deque, defaultdict
44
from datetime import timedelta
55
import functools
66
import logging
@@ -941,6 +941,19 @@ def _check_end(self):
941941
class partition(Stream):
942942
""" Partition stream into tuples of equal size
943943
944+
Parameters
945+
----------
946+
n: int
947+
Maximum partition size
948+
timeout: int or float, optional
949+
Number of seconds after which a partition will be emitted,
950+
even if its size is less than ``n``. If ``None`` (default),
951+
a partition will be emitted only when its size reaches ``n``.
952+
key: hashable or callable, optional
953+
Emit items with the same key together as a separate partition.
954+
If ``key`` is callable, partition will be identified by ``key(x)``,
955+
otherwise by ``x[key]``. Defaults to ``None``.
956+
944957
Examples
945958
--------
946959
>>> source = Stream()
@@ -950,30 +963,67 @@ class partition(Stream):
950963
(0, 1, 2)
951964
(3, 4, 5)
952965
(6, 7, 8)
966+
967+
>>> source = Stream()
968+
>>> source.partition(2, key=lambda x: x % 2).sink(print)
969+
>>> for i in range(4):
970+
... source.emit(i)
971+
(0, 2)
972+
(1, 3)
973+
974+
>>> from time import sleep
975+
>>> source = Stream()
976+
>>> source.partition(5, timeout=1).sink(print)
977+
>>> for i in range(3):
978+
... source.emit(i)
979+
>>> sleep(1)
980+
(0, 1, 2)
953981
"""
954982
_graphviz_shape = 'diamond'
955983

956-
def __init__(self, upstream, n, **kwargs):
984+
def __init__(self, upstream, n, timeout=None, key=None, **kwargs):
957985
self.n = n
958-
self._buffer = []
959-
self.metadata_buffer = []
960-
Stream.__init__(self, upstream, **kwargs)
986+
self._timeout = timeout
987+
self._key = key
988+
self._buffer = defaultdict(lambda: [])
989+
self._metadata_buffer = defaultdict(lambda: [])
990+
self._callbacks = {}
991+
Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
992+
993+
def _get_key(self, x):
994+
if self._key is None:
995+
return None
996+
if callable(self._key):
997+
return self._key(x)
998+
return x[self._key]
999+
1000+
@gen.coroutine
1001+
def _flush(self, key):
1002+
result, self._buffer[key] = self._buffer[key], []
1003+
metadata_result, self._metadata_buffer[key] = self._metadata_buffer[key], []
1004+
yield self._emit(tuple(result), list(metadata_result))
1005+
self._release_refs(metadata_result)
9611006

1007+
@gen.coroutine
9621008
def update(self, x, who=None, metadata=None):
9631009
self._retain_refs(metadata)
964-
self._buffer.append(x)
1010+
key = self._get_key(x)
1011+
buffer = self._buffer[key]
1012+
metadata_buffer = self._metadata_buffer[key]
1013+
buffer.append(x)
9651014
if isinstance(metadata, list):
966-
self.metadata_buffer.extend(metadata)
967-
else:
968-
self.metadata_buffer.append(metadata)
969-
if len(self._buffer) == self.n:
970-
result, self._buffer = self._buffer, []
971-
metadata_result, self.metadata_buffer = self.metadata_buffer, []
972-
ret = self._emit(tuple(result), list(metadata_result))
973-
self._release_refs(metadata_result)
974-
return ret
1015+
metadata_buffer.extend(metadata)
9751016
else:
976-
return []
1017+
metadata_buffer.append(metadata)
1018+
if len(buffer) == self.n:
1019+
if self._timeout is not None and self.n > 1:
1020+
self._callbacks[key].cancel()
1021+
yield self._flush(key)
1022+
return
1023+
if len(buffer) == 1 and self._timeout is not None:
1024+
self._callbacks[key] = self.loop.call_later(
1025+
self._timeout, self._flush, key
1026+
)
9771027

9781028

9791029
@Stream.register_api()

streamz/graph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,14 @@ def readable_graph(graph):
131131
def to_graphviz(graph, **graph_attr):
132132
import graphviz
133133

134-
gvz = graphviz.Digraph(graph_attr=graph_attr)
134+
digraph_kwargs = {'name', 'comment', 'filename',
135+
'format', 'engine', 'encoding',
136+
'graph_attr', 'node_attr', 'edge_attr',
137+
'body', 'strict', 'directory'}
138+
if not digraph_kwargs.intersection(graph_attr):
139+
graph_attr = dict(graph_attr=graph_attr)
140+
141+
gvz = graphviz.Digraph(**graph_attr)
135142
for node, attrs in graph.nodes.items():
136143
gvz.node(node, **attrs)
137144
for edge, attrs in graph.edges().items():

streamz/tests/test_core.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,65 @@ def test_partition():
164164
assert L == [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
165165

166166

167+
def test_partition_timeout():
168+
source = Stream()
169+
L = source.partition(10, timeout=0.01).sink_to_list()
170+
171+
for i in range(5):
172+
source.emit(i)
173+
174+
sleep(0.1)
175+
176+
assert L == [(0, 1, 2, 3, 4)]
177+
178+
179+
def test_partition_timeout_cancel():
180+
source = Stream()
181+
L = source.partition(3, timeout=0.1).sink_to_list()
182+
183+
for i in range(3):
184+
source.emit(i)
185+
186+
sleep(0.09)
187+
source.emit(3)
188+
sleep(0.02)
189+
190+
assert L == [(0, 1, 2)]
191+
192+
sleep(0.09)
193+
194+
assert L == [(0, 1, 2), (3,)]
195+
196+
197+
def test_partition_key():
198+
source = Stream()
199+
L = source.partition(2, key=0).sink_to_list()
200+
201+
for i in range(4):
202+
source.emit((i % 2, i))
203+
204+
assert L == [((0, 0), (0, 2)), ((1, 1), (1, 3))]
205+
206+
207+
def test_partition_key_callable():
208+
source = Stream()
209+
L = source.partition(2, key=lambda x: x % 2).sink_to_list()
210+
211+
for i in range(10):
212+
source.emit(i)
213+
214+
assert L == [(0, 2), (1, 3), (4, 6), (5, 7)]
215+
216+
217+
def test_partition_size_one():
218+
source = Stream()
219+
220+
source.partition(1, timeout=.01).sink(lambda x: None)
221+
222+
for i in range(10):
223+
source.emit(i)
224+
225+
167226
def test_sliding_window():
168227
source = Stream()
169228
L = source.sliding_window(2).sink_to_list()

0 commit comments

Comments
 (0)