Skip to content

Commit ae84982

Browse files
authored
Merge pull request #412 from mroeschke/feature/ewm
Add Expanding and EWM.mean
2 parents 56562ac + 95c7036 commit ae84982

3 files changed

Lines changed: 237 additions & 5 deletions

File tree

streamz/dataframe/aggregations.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,28 @@ def initial(self, new):
148148
return new.iloc[:0]
149149

150150

151+
class EWMean(Aggregation):
    """ Exponentially weighted mean, folded in one row at a time

    The accumulator is a ``(result, old_wt, is_first)`` triple:

    - ``result``: the current weighted mean (a one-row slice of the data,
      or an empty slice while no row has been seen yet)
    - ``old_wt``: the total weight carried by ``result``
    - ``is_first``: True until the first row has been absorbed

    Parameters
    ----------
    com: float
        Center of mass of the decay; the smoothing factor is
        ``alpha = 1 / (1 + com)``.
    """

    def __init__(self, com):
        self.com = com
        alpha = 1. / (1. + self.com)
        # Multiplier applied to the accumulated weight for every new row
        self.old_wt_factor = 1. - alpha
        # Weight given to each incoming row
        self.new_wt = 1.

    def on_new(self, acc, new):
        """ Fold the rows of ``new`` into the running mean

        Returns the updated accumulator and the mean after the last row of
        ``new`` (one result per batch, not one per row).
        """
        result, old_wt, is_first = acc
        if not len(new):
            # Nothing to absorb.  Leave ``is_first`` alone so that an empty
            # leading batch does not mark the (still empty) seed as valid.
            return acc, result
        if is_first and not len(result):
            # The seed was built from an empty batch; take it from this
            # batch instead so later arithmetic is not done on zero rows.
            result = new.iloc[:1]
        # When ``is_first`` is set the first row already is the seed, so the
        # loop starts at the second row.
        for i in range(int(is_first), len(new)):
            old_wt *= self.old_wt_factor
            result = ((old_wt * result) + (self.new_wt * new.iloc[i])) / (old_wt + self.new_wt)
            old_wt += self.new_wt
        return (result, old_wt, False), result

    def on_old(self, acc, old):
        """ Ignore evicted data; the mean does not change

        Returns the accumulator unchanged together with the current mean,
        mirroring the ``(state, result)`` shape returned by ``on_new``.
        """
        return acc, acc[0]

    def initial(self, new):
        """ Seed the accumulator with the first row of ``new`` and weight 1 """
        return new.iloc[:1], 1, True
171+
172+
151173
def diff_iloc(dfs, new, window=None):
152174
""" Emit new list of dfs and decayed data
153175
@@ -223,6 +245,13 @@ def diff_loc(dfs, new, window=None):
223245
return dfs, old
224246

225247

248+
def diff_expanding(dfs, new, window=None):
    """ Emit new list of dfs and no decayed data

    ``new`` is appended to a fresh deque built from ``dfs`` when it is
    non-empty.  Nothing is ever evicted, so the second element of the
    returned pair is always an empty list.  ``window`` is accepted but
    not used.
    """
    retained = deque(dfs)
    if len(new):
        retained.append(new)
    return retained, []
253+
254+
226255
def diff_align(dfs, groupers):
227256
""" Align groupers to newly-diffed dataframes
228257

streamz/dataframe/core.py

Lines changed: 139 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ def window(self, n=None, value=None, with_state=False, start=None):
170170
"""
171171
return Window(self, n=n, value=value, with_state=with_state, start=start)
172172

173+
def expanding(self, with_state=False, start=None):
    """ Wrap this dataframe in an ``Expanding`` object

    ``with_state`` and ``start`` are forwarded unchanged; ``n`` is fixed
    to 1.
    """
    options = {'n': 1, 'with_state': with_state, 'start': start}
    return Expanding(self, **options)
175+
176+
def ewm(self, com=None, span=None, halflife=None, alpha=None, with_state=False, start=None):
    """ Wrap this dataframe in an ``EWM`` object

    The decay arguments (``com``, ``span``, ``halflife``, ``alpha``) and
    ``with_state`` / ``start`` are forwarded unchanged; ``n`` is fixed
    to 1.
    """
    return EWM(
        self,
        n=1,
        com=com,
        span=span,
        halflife=halflife,
        alpha=alpha,
        with_state=with_state,
        start=start,
    )
178+
173179
def _cumulative_aggregation(self, op):
174180
return self.accumulate_partitions(_cumulative_accumulator,
175181
returns_state=True,
@@ -531,18 +537,30 @@ def __init__(self, sdf, n=None, value=None, with_state=False, start=None):
531537

532538
def __getitem__(self, key):
    """ Select ``key`` from the wrapped dataframe, keeping the window settings

    The result has the same type as ``self`` and carries over ``n``,
    ``value``, ``with_state`` and ``start``.
    """
    selected = self.root[key]
    settings = dict(
        n=self.n,
        value=self.value,
        with_state=self.with_state,
        start=self.start,
    )
    return type(self)(selected, **settings)
535547

536548
def __getattr__(self, key):
    """ Fall back to column selection for unknown attributes

    If ``key`` is a column of the wrapped dataframe, or the dataframe has
    no columns at all, this is equivalent to ``self[key]``.

    Raises
    ------
    AttributeError
        If ``key`` is not one of the columns.
    """
    if key in self.root.columns or not len(self.root.columns):
        return self[key]
    else:
        # Use the bare class name and a quoted key so the message reads
        # "Window has no attribute 'foo'" rather than
        # "<class '...Window'> has no attribute foo".
        raise AttributeError(f"{type(self).__name__} has no attribute {key!r}")
541553

542554
def map_partitions(self, func, *args, **kwargs):
    """ Apply ``func`` to the wrapped dataframe and re-wrap the result

    Any positional argument that is itself a window object is replaced by
    the dataframe it wraps before the call is delegated to
    ``self.root.map_partitions``.  The result has the same type as
    ``self`` and carries over ``n``, ``value``, ``with_state`` and
    ``start``.
    """
    # Test against the base class, not ``type(self)``: a subclass instance
    # must still unwrap arguments that are plain or sibling windows.
    args2 = [a.root if isinstance(a, Window) else a for a in args]
    root = self.root.map_partitions(func, *args2, **kwargs)
    return type(self)(
        root,
        n=self.n,
        value=self.value,
        with_state=self.with_state,
        start=self.start
    )
546564

547565
@property
548566
def index(self):
@@ -561,7 +579,7 @@ def example(self):
561579
return self.root.example
562580

563581
def reset_index(self):
    """ Reset the index of the wrapped dataframe, keeping the window settings

    The result has the same type as ``self``.  ``with_state`` and
    ``start`` are forwarded along with ``n`` and ``value`` so that the
    returned window is configured like the one it was derived from, as
    ``__getitem__`` and ``map_partitions`` already do.
    """
    return type(self)(
        self.root.reset_index(),
        n=self.n,
        value=self.value,
        with_state=self.with_state,
        start=self.start
    )
565583

566584
def aggregate(self, agg):
567585
if self.n is not None:
@@ -622,6 +640,122 @@ def groupby(self, other):
622640
self.with_state, self.start)
623641

624642

643+
class Expanding(Window):
    """ Window whose aggregations use ``aggregations.diff_expanding`` """

    def aggregate(self, agg):
        """ Run ``agg`` through ``window_accumulator`` with the expanding diff

        Delegates to ``self.root.accumulate_partitions`` and returns
        whatever it returns.
        """
        return self.root.accumulate_partitions(
            aggregations.window_accumulator,
            diff=aggregations.diff_expanding,
            window=self.n,
            agg=agg,
            start=self.start,
            returns_state=True,
            stream_type='updating',
            with_state=self.with_state,
        )

    def groupby(self, other):
        """ Not supported; always raises ``NotImplementedError`` """
        raise NotImplementedError
659+
660+
661+
class EWM(Expanding):
    """ Exponentially weighted window; only ``mean`` is implemented

    Exactly one of ``com``, ``span``, ``halflife`` or ``alpha`` must be
    given.  The value is validated and converted to a center of mass,
    stored in ``self._com``; the arguments as passed are kept on
    ``self.com``, ``self.span``, ``self.halflife`` and ``self.alpha`` so
    that derived objects can be rebuilt with the same configuration.
    """

    def __init__(
        self,
        sdf,
        n=1,
        value=None,
        with_state=False,
        start=None,
        com=None,
        span=None,
        halflife=None,
        alpha=None
    ):
        super().__init__(sdf, n=n, value=value, with_state=with_state, start=start)
        self._com = self._get_com(com, span, halflife, alpha)
        self.com = com
        self.span = span
        self.alpha = alpha
        self.halflife = halflife

    def _rewrap(self, sdf):
        """ Build a new object of the same type around ``sdf`` with this configuration """
        return type(self)(
            sdf,
            n=self.n,
            value=self.value,
            with_state=self.with_state,
            start=self.start,
            com=self.com,
            span=self.span,
            halflife=self.halflife,
            alpha=self.alpha
        )

    def __getitem__(self, key):
        """ Select ``key`` from the wrapped dataframe, keeping the decay settings """
        return self._rewrap(self.root[key])

    def map_partitions(self, func, *args, **kwargs):
        """ Apply ``func`` to the wrapped dataframe and re-wrap the result

        Overridden because the inherited version rebuilds the object
        without any decay argument, which makes ``_get_com`` raise.
        """
        args2 = [a.root if isinstance(a, Window) else a for a in args]
        return self._rewrap(self.root.map_partitions(func, *args2, **kwargs))

    def reset_index(self):
        """ Reset the index of the wrapped dataframe, keeping the decay settings

        Overridden for the same reason as ``map_partitions``.
        """
        return self._rewrap(self.root.reset_index())

    @staticmethod
    def _get_com(com, span, halflife, alpha):
        """ Validate the decay arguments and convert them to a center of mass

        Raises
        ------
        ValueError
            If more than one argument is given, none is given, or the one
            given lies outside its valid domain.
        """
        if sum(var is not None for var in (com, span, halflife, alpha)) > 1:
            raise ValueError("Can only provide one of `com`, `span`, `halflife`, `alpha`.")
        # Convert to center of mass; domain checks ensure 0 < alpha <= 1
        if com is not None:
            if com < 0:
                raise ValueError("com must satisfy: com >= 0")
        elif span is not None:
            if span < 1:
                raise ValueError("span must satisfy: span >= 1")
            com = (span - 1) / 2
        elif halflife is not None:
            if halflife <= 0:
                raise ValueError("halflife must satisfy: halflife > 0")
            decay = 1 - np.exp(np.log(0.5) / halflife)
            com = 1 / decay - 1
        elif alpha is not None:
            if alpha <= 0 or alpha > 1:
                raise ValueError("alpha must satisfy: 0 < alpha <= 1")
            com = (1 - alpha) / alpha
        else:
            raise ValueError("Must pass one of com, span, halflife, or alpha")

        return float(com)

    def full(self):
        """ Not implemented for exponentially weighted windows """
        raise NotImplementedError

    def apply(self, func):
        """ Apply an arbitrary function over each window of data """
        raise NotImplementedError

    def sum(self):
        """ Sum elements within window """
        raise NotImplementedError

    def count(self):
        """ Count elements within window """
        raise NotImplementedError

    def mean(self):
        """ Exponentially weighted average of the elements seen so far """
        return self.aggregate(aggregations.EWMean(self._com))

    def var(self, ddof=1):
        """ Compute variance of elements within window """
        raise NotImplementedError

    def std(self, ddof=1):
        """ Compute standard deviation of elements within window """
        raise NotImplementedError

    @property
    def size(self):
        """ Number of elements within window """
        raise NotImplementedError

    def value_counts(self):
        """ Count groups of elements within window """
        raise NotImplementedError
757+
758+
625759
def rolling_accumulator(acc, new, window=None, op=None,
626760
with_state=False, args=(), kwargs={}):
627761
if len(acc):

streamz/dataframe/tests/test_dataframes.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,75 @@ def test_windowing_n(func, n, getter):
709709
assert_eq(L[-1], func(getter(df).iloc[len(df) - n:] + 10))
710710

711711

712+
@pytest.mark.parametrize('func', [
    pytest.param(lambda x: x.sum(), id="sum"),
    pytest.param(lambda x: x.mean(), id="mean"),
    pytest.param(lambda x: x.count(), id="count"),
    pytest.param(lambda x: x.var(ddof=1), id="var_1"),
    pytest.param(lambda x: x.std(ddof=1), id="std"),
    pytest.param(lambda x: x.var(ddof=0), id="var_0"),
])
def test_expanding(func):
    """ Streaming expanding aggregations agree with pandas ``expanding`` """
    n_batches = 5
    batch = pd.DataFrame({'x': [1.], 'y': [2.]})
    sdf = DataFrame(example=batch)

    L = func(sdf.expanding()).stream.gather().sink_to_list()

    for _ in range(n_batches):
        sdf.emit(batch)

    # Each emitted result is turned into one row of the comparison frame
    result = pd.concat(L, axis=1).T.astype(float)
    stacked = pd.concat([batch] * n_batches, ignore_index=True)
    expected = func(stacked.expanding())
    assert_eq(result, expected)
732+
733+
734+
def test_ewm_mean():
    """ Streaming ``ewm(1).mean()`` agrees with pandas on the same rows """
    batches = [
        pd.DataFrame({'x': [1.], 'y': [2.]}),
        pd.DataFrame({'x': [2.], 'y': [3.]}),
        pd.DataFrame({'x': [3.], 'y': [4.]}),
    ]
    sdf = DataFrame(example=pd.DataFrame(columns=['x', 'y']))
    L = sdf.ewm(1).mean().stream.gather().sink_to_list()
    for batch in batches:
        sdf.emit(batch)
    result = pd.concat(L, ignore_index=True)

    # The same three rows as one frame, for the pandas reference
    df = pd.concat(batches, ignore_index=True)
    expected = df.ewm(1).mean()
    assert_eq(result, expected)
745+
746+
747+
def test_ewm_raise_multiple_arguments():
    """ Passing two decay arguments at once is rejected """
    empty = pd.DataFrame(columns=['x', 'y'])
    sdf = DataFrame(example=empty)
    conflicting = {'com': 1, 'halflife': 1}
    with pytest.raises(ValueError, match="Can only provide one of"):
        sdf.ewm(**conflicting)
751+
752+
753+
def test_ewm_raise_no_argument():
    """ Calling ``ewm`` without any decay argument is rejected """
    empty = pd.DataFrame(columns=['x', 'y'])
    sdf = DataFrame(example=empty)
    with pytest.raises(ValueError, match="Must pass one of"):
        sdf.ewm()
757+
758+
759+
@pytest.mark.parametrize("arg", ["com", "halflife", "alpha", "span"])
def test_raise_invalid_argument(arg):
    """ A value of -1 is outside the valid domain of every decay argument """
    empty = pd.DataFrame(columns=['x', 'y'])
    sdf = DataFrame(example=empty)
    with pytest.raises(ValueError):
        sdf.ewm(**{arg: -1})
765+
766+
767+
@pytest.mark.parametrize('func', [
    lambda x: x.sum(),
    lambda x: x.count(),
    lambda x: x.apply(lambda x: x),
    lambda x: x.full(),
    lambda x: x.var(),
    lambda x: x.std(),
    lambda x: x.size,
    lambda x: x.value_counts()
], ids=["sum", "count", "apply", "full", "var", "std", "size", "value_counts"])
def test_ewm_notimplemented(func):
    """ Every aggregation that EWM stubs out raises ``NotImplementedError``

    Covers all of the stubs defined on ``EWM``, including the ``size``
    property and ``value_counts``.
    """
    sdf = DataFrame(example=pd.DataFrame(columns=['x', 'y']))
    with pytest.raises(NotImplementedError):
        func(sdf.ewm(1))
779+
780+
712781
@pytest.mark.parametrize('func', [
713782
lambda x: x.sum(),
714783
lambda x: x.mean(),

0 commit comments

Comments
 (0)