Skip to content

Commit fc379a0

Browse files
authored
Merge pull request #418 from martindurant/example
Allow example for dataframe to not be a dataframe
2 parents 0bd2f50 + b52dbca commit fc379a0

6 files changed

Lines changed: 14 additions & 16 deletions

File tree

streamz/collection.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,7 @@ def __init__(self, stream=None, example=None, stream_type=None):
177177
assert example is not None
178178
self.example = example
179179
if not isinstance(self.example, self._subtype):
180-
msg = ("For streaming type %s we expect an example of type %s. "
181-
"Got %s") % (type(self).__name__, self._subtype.__name__,
182-
str(self.example))
183-
raise TypeError(msg)
180+
self.example = self._subtype(example)
184181
assert isinstance(self.example, self._subtype)
185182
self.stream = stream or Stream()
186183
if stream_type:

streamz/core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from __future__ import absolute_import, division, print_function
21
from collections import deque, defaultdict
32
from datetime import timedelta
43
import functools

streamz/dataframe/core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import pandas as pd
77
import toolz
88

9-
from tornado import gen
10-
119
from ..collection import Streaming, _stream_types, OperatorMixin
1210
from ..sources import Source
1311
from ..utils import M
@@ -1049,7 +1047,7 @@ def stop(self):
10491047
async def _cb(interval, source, continue_):
10501048
last = pd.Timestamp.now()
10511049
while continue_[0]:
1052-
await gen.sleep(interval)
1050+
await asyncio.sleep(interval)
10531051
now = pd.Timestamp.now()
10541052
await asyncio.gather(*source._emit(dict(last=last, now=now)))
10551053
last = now

streamz/dataframe/tests/test_dataframe_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ def test_utils_get_base_frame_type_pandas():
1818
with pytest.raises(TypeError):
1919
get_base_frame_type("Index", is_index_like, df)
2020

21-
with pytest.raises(TypeError):
22-
get_base_frame_type("DataFrame", is_dataframe_like, df.x)
21+
# casts Series to DataFrame, if that's what we ask for
22+
assert pd.DataFrame == get_base_frame_type("DataFrame", is_dataframe_like, df.x)
2323
assert pd.Series == get_base_frame_type("Series", is_series_like, df.x)
2424
with pytest.raises(TypeError):
2525
get_base_frame_type("Index", is_index_like, df.x)
2626

27-
with pytest.raises(TypeError):
28-
get_base_frame_type("DataFrame", is_dataframe_like, df.index)
27+
# casts Series to DataFrame, if that's what we ask for
28+
assert pd.DataFrame == get_base_frame_type("DataFrame", is_dataframe_like, df.index)
2929
with pytest.raises(TypeError):
3030
get_base_frame_type("Series", is_series_like, df.index)
3131
assert issubclass(get_base_frame_type("Index", is_index_like, df.index), pd.Index)

streamz/dataframe/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ def get_base_frame_type(frame_name, is_frame_like, example=None):
3838
Returns the base type of streaming objects if type checks pass."""
3939
if example is None:
4040
raise TypeError("Missing required argument:'example'")
41-
if not is_frame_like(example):
41+
if is_frame_like is is_dataframe_like and not is_frame_like(example):
42+
import pandas as pd
43+
example = pd.DataFrame(example)
44+
45+
elif not is_frame_like(example):
4246
msg = "Streaming {0} expects an example of {0} like objects. Got: {1}."\
4347
.format(frame_name, example)
4448
raise TypeError(msg)

streamz/utils_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from contextlib import contextmanager
23
import logging
34
import os
@@ -116,11 +117,10 @@ def wait_for(predicate, timeout, fail_func=None, period=0.001):
116117
pytest.fail("condition not reached within %s seconds" % timeout)
117118

118119

119-
@gen.coroutine
120-
def await_for(predicate, timeout, fail_func=None, period=0.001):
120+
async def await_for(predicate, timeout, fail_func=None, period=0.001):
121121
deadline = time() + timeout
122122
while not predicate():
123-
yield gen.sleep(period)
123+
await asyncio.sleep(period)
124124
if time() > deadline: # pragma: no cover
125125
if fail_func is not None:
126126
fail_func()

0 commit comments

Comments
 (0)