Skip to content

Commit dda82ff

Browse files
committed
tests, check class of node before registering
1 parent b52b92c commit dda82ff

3 files changed

Lines changed: 49 additions & 17 deletions

File tree

streamz/core.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def register_api(cls, modifier=identity, attribute_name=None):
273273
>>> Stream().foo(...) # this works now
274274
275275
It attaches the callable as a normal attribute to the class object. In
276-
doing so it respsects inheritance (all subclasses of Stream will also
276+
doing so it respects inheritance (all subclasses of Stream will also
277277
get the foo attribute).
278278
279279
By default callables are assumed to be instance methods. If you like
@@ -285,6 +285,15 @@ def register_api(cls, modifier=identity, attribute_name=None):
285285
... ...
286286
287287
>>> Stream.foo(...) # Foo operates as a static method
288+
289+
You can also provide an optional ``attribute_name`` argument to control
290+
the name of the attribute your callable will be attached as.
291+
292+
>>> @Stream.register_api(attribute_name="bar")
293+
... class foo(Stream):
294+
... ...
295+
296+
>> Stream().bar(...) # foo was actually attached as bar
288297
"""
289298
def _(func):
290299
@functools.wraps(func)
@@ -298,11 +307,17 @@ def wrapped(*args, **kwargs):
298307
@classmethod
299308
def register_plugin_entry_point(cls, entry_point, modifier=identity):
300309
def stub(*args, **kwargs):
301-
attribute = entry_point.load()
310+
node = entry_point.load()
311+
if not issubclass(node, Stream):
312+
raise TypeError(
313+
f"Error loading {entry_point.name} "
314+
f"from module {entry_point.module_name}: "
315+
f"{entry_point.cls.__name__} must be a subclass of Stream"
316+
)
302317
cls.register_api(
303318
modifier=modifier, attribute_name=entry_point.name
304-
)(attribute)
305-
return attribute(*args, **kwargs)
319+
)(node)
320+
return node(*args, **kwargs)
306321
cls.register_api(modifier=modifier, attribute_name=entry_point.name)(stub)
307322

308323
def start(self):

streamz/plugins.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
def load_plugins(cls):
55
for entry_point in pkg_resources.iter_entry_points("streamz.sources"):
6-
cls.register_plugin_entrypoint(entry_point, staticmethod)
6+
cls.register_plugin_entry_point(entry_point, staticmethod)
77
for entry_point in pkg_resources.iter_entry_points("streamz.nodes"):
8-
cls.register_plugin_entrypoint(entry_point)
8+
cls.register_plugin_entry_point(entry_point)
99
for entry_point in pkg_resources.iter_entry_points("streamz.sinks"):
10-
cls.register_plugin_entrypoint(entry_point)
10+
cls.register_plugin_entry_point(entry_point)

streamz/tests/test_plugins.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,55 @@
1-
from streamz.sources import Source
2-
from streamz import Stream
1+
import pytest
2+
from streamz import Source, Stream
33

44

55
class MockEntryPoint:
66

7-
def __init__(self, name, cls):
7+
def __init__(self, name, cls, module_name=None):
88
self.name = name
99
self.cls = cls
10+
self.module_name = module_name
1011

1112
def load(self):
1213
return self.cls
1314

1415

1516
def test_register_plugin_entry_point():
16-
class test(Stream):
17+
class test_stream(Stream):
1718
pass
1819

19-
entry_point = MockEntryPoint("test_node", test)
20+
entry_point = MockEntryPoint("test_node", test_stream)
2021
Stream.register_plugin_entry_point(entry_point)
2122

2223
assert Stream.test_node.__name__ == "stub"
2324

2425
Stream().test_node()
2526

26-
assert Stream.test_node.__name__ == "test"
27+
assert Stream.test_node.__name__ == "test_stream"
2728

2829

2930
def test_register_plugin_entry_point_modifier():
30-
class test(Source):
31+
class test_source(Source):
3132
pass
3233

33-
entry_point = MockEntryPoint("from_test", test)
34-
Stream.register_plugin_entry_point(entry_point, staticmethod)
34+
def modifier(fn):
35+
fn.__name__ = 'modified_name'
36+
return staticmethod(fn)
37+
38+
entry_point = MockEntryPoint("from_test", test_source)
39+
Stream.register_plugin_entry_point(entry_point, modifier)
3540

3641
Stream.from_test()
3742

38-
assert Stream.from_test.__self__ is Stream
43+
assert Stream.from_test.__name__ == "modified_name"
44+
45+
46+
def test_register_plugin_entry_point_raises():
47+
class invalid_node:
48+
pass
49+
50+
entry_point = MockEntryPoint("test", invalid_node, "test_module.test")
51+
52+
Stream.register_plugin_entry_point(entry_point)
53+
54+
with pytest.raises(TypeError):
55+
Stream.test()

0 commit comments

Comments
 (0)