Skip to content

Commit 975231c

Browse files
committed
Add tests for ewm and notimplementederrors
1 parent 0286d01 commit 975231c

2 files changed

Lines changed: 35 additions & 5 deletions

File tree

streamz/dataframe/aggregations.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,24 +148,27 @@ def initial(self, new):
148148
return new.iloc[:0]
149149

150150

151-
class EWMean(Mean):
151+
class EWMean(Aggregation):
152152
def __init__(self, com):
153153
self.com = com
154154
alpha = 1. / (1. + self.com)
155155
self.old_wt_factor = 1. - alpha
156156
self.new_wt = 1.
157157

158158
def on_new(self, acc, new):
159-
result, old_wt = acc
160-
for i in range(len(new)):
161-
result = ((old_wt * result) + (self.new_wt * new.iloc[i])) / (old_wt + self.new_wt)
159+
result, old_wt, is_first = acc
160+
for i in range(int(is_first), len(new)):
162161
old_wt *= self.old_wt_factor
162+
result = ((old_wt * result) + (self.new_wt * new.iloc[i])) / (old_wt + self.new_wt)
163163
old_wt += self.new_wt
164-
return (result, old_wt), result
164+
return (result, old_wt, False), result
165165

166166
def on_old(self, acc, old):
167167
pass
168168

169+
def initial(self, new):
170+
return new.iloc[:1], 1, True
171+
169172

170173
def diff_iloc(dfs, new, window=None):
171174
""" Emit new list of dfs and decayed data

streamz/dataframe/tests/test_dataframes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,33 @@ def test_expanding(func):
731731
assert_eq(result, expected)
732732

733733

734+
def test_ewm_mean():
735+
sdf = DataFrame(example=pd.DataFrame(columns=['x', 'y']))
736+
L = sdf.ewm(1).mean().stream.gather().sink_to_list()
737+
sdf.emit(pd.DataFrame({'x': [1.], 'y': [2.]}))
738+
sdf.emit(pd.DataFrame({'x': [2.], 'y': [3.]}))
739+
sdf.emit(pd.DataFrame({'x': [3.], 'y': [4.]}))
740+
result = pd.concat(L, ignore_index=True)
741+
742+
df = pd.DataFrame({'x': [1., 2., 3.], 'y': [2., 3., 4.]})
743+
expected = df.ewm(1).mean()
744+
assert_eq(result, expected)
745+
746+
747+
@pytest.mark.parametrize('func', [
748+
lambda x: x.sum(),
749+
lambda x: x.count(),
750+
lambda x: x.apply(lambda x: x),
751+
lambda x: x.full(),
752+
lambda x: x.var(),
753+
lambda x: x.std()
754+
], ids=["sum", "count", "apply", "full", "var", "std"])
755+
def test_ewm_notimplemented(func):
756+
sdf = DataFrame(example=pd.DataFrame(columns=['x', 'y']))
757+
with pytest.raises(NotImplementedError):
758+
func(sdf.ewm(1))
759+
760+
734761
@pytest.mark.parametrize('func', [
735762
lambda x: x.sum(),
736763
lambda x: x.mean(),

0 commit comments

Comments
 (0)