11from __future__ import absolute_import , division , print_function
22
3- from collections import deque
3+ from collections import deque , defaultdict
44from datetime import timedelta
55import functools
66import logging
@@ -941,6 +941,19 @@ def _check_end(self):
941941class partition (Stream ):
942942 """ Partition stream into tuples of equal size
943943
944+ Parameters
945+ ----------
946+ n: int
947+ Maximum partition size
948+ timeout: int or float, optional
949+ Number of seconds after which a partition will be emitted,
950+ even if its size is less than ``n``. If ``None`` (default),
951+ a partition will be emitted only when its size reaches ``n``.
952+ key: hashable or callable, optional
953+ Emit items with the same key together as a separate partition.
954+ If ``key`` is callable, partition will be identified by ``key(x)``,
955+ otherwise by ``x[key]``. Defaults to ``None``.
956+
944957 Examples
945958 --------
946959 >>> source = Stream()
@@ -950,30 +963,67 @@ class partition(Stream):
950963 (0, 1, 2)
951964 (3, 4, 5)
952965 (6, 7, 8)
966+
967+ >>> source = Stream()
968+ >>> source.partition(2, key=lambda x: x % 2).sink(print)
969+ >>> for i in range(4):
970+ ... source.emit(i)
971+ (0, 2)
972+ (1, 3)
973+
974+ >>> from time import sleep
975+ >>> source = Stream()
976+ >>> source.partition(5, timeout=1).sink(print)
977+ >>> for i in range(3):
978+ ... source.emit(i)
979+ >>> sleep(1)
980+ (0, 1, 2)
953981 """
954982 _graphviz_shape = 'diamond'
955983
def __init__(self, upstream, n, timeout=None, key=None, **kwargs):
    """Create the per-key buffers and attach this node to ``upstream``.

    Parameters
    ----------
    upstream:
        Node passed through to ``Stream.__init__``.
    n: int
        Maximum partition size.
    timeout: int or float, optional
        Number of seconds after which a partially filled partition is
        emitted. ``None`` (default) disables the timer.
    key: hashable or callable, optional
        Selector used by ``_get_key`` to group elements into separate
        partitions. ``None`` (default) puts every element in one group.
    **kwargs
        Forwarded unchanged to ``Stream.__init__``.
    """
    self.n = n
    self._timeout = timeout
    self._key = key
    # ``defaultdict(list)`` behaves exactly like a ``lambda: []`` factory
    # but is the idiomatic spelling and keeps the buffers picklable.
    self._buffer = defaultdict(list)
    self._metadata_buffer = defaultdict(list)
    # Pending timeout handles, indexed by the same key as the buffers.
    self._callbacks = {}
    # ``update`` schedules timeouts through ``self.loop``; presumably
    # ``ensure_io_loop=True`` is what guarantees that loop exists -- confirm
    # against ``Stream.__init__``.
    Stream.__init__(self, upstream, ensure_io_loop=True, **kwargs)
992+
993+ def _get_key (self , x ):
994+ if self ._key is None :
995+ return None
996+ if callable (self ._key ):
997+ return self ._key (x )
998+ return x [self ._key ]
999+
@gen.coroutine
def _flush(self, key):
    """Emit everything buffered under ``key`` as one tuple.

    Called from ``update`` when a partition reaches ``n`` elements and
    scheduled on the loop as the timeout callback for a partially filled
    partition.

    Parameters
    ----------
    key:
        Partition identifier as produced by ``_get_key``.
    """
    # Remove the per-key state instead of resetting it to an empty list:
    # otherwise every key ever seen would keep two empty lists and a stale
    # timer handle alive forever, which leaks on high-cardinality keys.
    # The ``defaultdict`` buffers recreate the entries on the next element.
    # All of this happens before the first ``yield`` so that elements
    # arriving while the emit is in flight start from a clean slate.
    result = self._buffer.pop(key, [])
    metadata_result = self._metadata_buffer.pop(key, [])
    self._callbacks.pop(key, None)
    yield self._emit(tuple(result), list(metadata_result))
    # Balances the ``_retain_refs`` calls made in ``update`` for the
    # metadata that has now been passed downstream.
    self._release_refs(metadata_result)
9611006
@gen.coroutine
def update(self, x, who=None, metadata=None):
    """Buffer ``x`` and emit its partition once it holds ``n`` elements.

    Parameters
    ----------
    x:
        Incoming element; its partition is chosen with ``_get_key``.
    who:
        Not used by this method.
    metadata: list or object, optional
        Metadata accompanying ``x``. A list is merged into the partition's
        metadata, any other value (including ``None``) is appended as-is.

    Notes
    -----
    When a timeout is configured, the first element of a new partition
    arms a timer that flushes the partition after ``timeout`` seconds even
    if it is still smaller than ``n``.
    """
    self._retain_refs(metadata)
    key = self._get_key(x)
    buffer = self._buffer[key]
    metadata_buffer = self._metadata_buffer[key]
    buffer.append(x)
    if isinstance(metadata, list):
        metadata_buffer.extend(metadata)
    else:
        metadata_buffer.append(metadata)

    if len(buffer) == self.n:
        # The partition is full: the pending timeout (if any) must not fire.
        # ``call_later`` returns an opaque handle; ``remove_timeout`` is the
        # documented way to cancel it, whereas calling ``.cancel()`` on the
        # handle only works on asyncio-backed loops.
        handle = self._callbacks.pop(key, None)
        if handle is not None:
            self.loop.remove_timeout(handle)
        yield self._flush(key)
        return

    if len(buffer) == 1 and self._timeout is not None:
        # First element of a fresh partition: arm the flush timer.
        self._callbacks[key] = self.loop.call_later(
            self._timeout, self._flush, key
        )
9771027
9781028
9791029@Stream .register_api ()
0 commit comments