|
8 | 8 | import sys |
9 | 9 | import threading |
10 | 10 | from time import time |
| 11 | +from typing import Any, Callable, Hashable, Union |
11 | 12 | import weakref |
12 | 13 |
|
13 | 14 | import toolz |
@@ -1070,6 +1071,107 @@ def update(self, x, who=None, metadata=None): |
1070 | 1071 | ) |
1071 | 1072 |
|
1072 | 1073 |
|
@Stream.register_api()
class partition_unique(Stream):
    """
    Partition stream elements into groups of equal size with unique keys only.

    Parameters
    ----------
    n: int
        Number of (unique) elements to pass through as a group.
    key: Union[Hashable, Callable[[Any], Hashable]]
        Callable that accepts a stream element and returns a unique, hashable
        representation of the incoming data (``key(x)``), or a hashable that gets
        the corresponding value of a stream element (``x[key]``). For example,
        ``key=lambda x: x["a"]`` would allow only elements with unique ``"a"`` values
        to pass through.

        .. note:: By default, we simply use the element object itself as the key,
           so that object must be hashable. If that's not the case, a non-default
           key must be provided.

    keep: str
        Which element to keep in the case that a unique key is already found
        in the group. If "first", keep element from the first occurrence of a given
        key; if "last", keep element from the most recent occurrence. Note that
        relative ordering of *elements* is preserved in the data passed through,
        and not ordering of *keys*.
    **kwargs

    Raises
    ------
    ValueError
        If ``keep`` is neither ``"first"`` nor ``"last"``.

    Examples
    --------
    >>> source = Stream()
    >>> stream = source.partition_unique(n=3, keep="first").sink(print)
    >>> eles = [1, 2, 1, 3, 1, 3, 3, 2]
    >>> for ele in eles:
    ...     source.emit(ele)
    (1, 2, 3)
    (1, 3, 2)

    >>> source = Stream()
    >>> stream = source.partition_unique(n=3, keep="last").sink(print)
    >>> eles = [1, 2, 1, 3, 1, 3, 3, 2]
    >>> for ele in eles:
    ...     source.emit(ele)
    (2, 1, 3)
    (1, 3, 2)

    >>> source = Stream()
    >>> stream = source.partition_unique(n=3, key=lambda x: len(x), keep="last").sink(print)
    >>> eles = ["f", "fo", "f", "foo", "f", "foo", "foo", "fo"]
    >>> for ele in eles:
    ...     source.emit(ele)
    ('fo', 'f', 'foo')
    ('f', 'foo', 'fo')
    """
    _graphviz_shape = "diamond"

    def __init__(
        self,
        upstream,
        n: int,
        key: Union[Hashable, Callable[[Any], Hashable]] = identity,
        keep: str = "first",  # Literal["first", "last"]
        **kwargs
    ):
        # Fail fast on an invalid ``keep``: previously any unrecognized value
        # silently behaved like "first".
        if keep not in ("first", "last"):
            raise ValueError(
                "keep must be 'first' or 'last', got {!r}".format(keep)
            )
        self.n = n
        self.key = key
        self.keep = keep
        # Insertion-ordered mapping of unique key -> element; emitted as a
        # tuple once it holds ``n`` entries.
        self._buffer = {}
        # Parallel mapping of unique key -> metadata for ref-counting.
        self._metadata_buffer = {}
        Stream.__init__(self, upstream, **kwargs)

    def _get_key(self, x):
        """Return the group key for element *x* (``key(x)`` if callable, else ``x[key]``)."""
        if callable(self.key):
            return self.key(x)
        else:
            return x[self.key]

    def update(self, x, who=None, metadata=None):
        self._retain_refs(metadata)
        y = self._get_key(x)
        if self.keep == "last":
            if y in self._buffer:
                # Remove the existing entry so that re-insertion reflects the
                # elements' actual relative ordering, and release the refs of
                # the displaced element's metadata — it will never be emitted.
                del self._buffer[y]
                self._release_refs(self._metadata_buffer.pop(y, None))
            self._buffer[y] = x
            self._metadata_buffer[y] = metadata
        else:  # self.keep == "first"
            if y in self._buffer:
                # Duplicate key: this element is dropped, so release the refs
                # we retained above (they would otherwise leak).
                self._release_refs(metadata)
            else:
                self._buffer[y] = x
                self._metadata_buffer[y] = metadata
        if len(self._buffer) == self.n:
            result, self._buffer = tuple(self._buffer.values()), {}
            metadata_result, self._metadata_buffer = list(self._metadata_buffer.values()), {}
            ret = self._emit(result, metadata_result)
            self._release_refs(metadata_result)
            return ret
        else:
            return []
| 1174 | + |
1073 | 1175 | @Stream.register_api() |
1074 | 1176 | class sliding_window(Stream): |
1075 | 1177 | """ Produce overlapping tuples of size n |
|
0 commit comments