Skip to content

Commit a336b0c

Browse files
author
Martin Durant
committed
refactor and make less tornado
1 parent b82ca42 commit a336b0c

4 files changed

Lines changed: 93 additions & 144 deletions

File tree

streamz/compatibility.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

streamz/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from collections.abc import Iterable
2525

26-
from .compatibility import get_thread_identity
26+
from threading import get_ident as get_thread_identity
2727
from .orderedweakset import OrderedWeakrefSet
2828

2929
no_default = '--no-default--'

streamz/sources.py

Lines changed: 83 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import asyncio
12
from glob import glob
23
import os
3-
4+
import inspect
45
import time
56
import tornado.ioloop
67
from tornado import gen
@@ -35,15 +36,26 @@ def write(text):
3536
class Source(Stream):
3637
_graphviz_shape = 'doubleoctagon'
3738

38-
def __init__(self, **kwargs):
39+
def __init__(self, start=False, **kwargs):
3940
self.stopped = True
40-
super(Source, self).__init__(**kwargs)
41+
super().__init__(ensure_io_loop=True, **kwargs)
42+
if start:
43+
self.start()
4144

42-
def stop(self): # pragma: no cover
43-
# fallback stop method - for poll functions with while not self.stopped
45+
def stop(self):
46+
"""set self.stopped, which will cause polling to stop after next run"""
4447
if not self.stopped:
4548
self.stopped = True
4649

50+
def start(self):
51+
"""start polling"""
52+
self.stopped = False
53+
self.loop.add_callback(self.run)
54+
55+
async def run(self):
56+
while not self.stopped:
57+
await self._run()
58+
4759

4860
@Stream.register_api(staticmethod)
4961
class from_textfile(Source):
@@ -74,44 +86,32 @@ class from_textfile(Source):
7486
-------
7587
Stream
7688
"""
77-
def __init__(self, f, poll_interval=0.100, delimiter='\n', start=False,
89+
def __init__(self, f, poll_interval=0.100, delimiter='\n',
7890
from_end=False, **kwargs):
7991
if isinstance(f, str):
8092
f = open(f)
93+
self.buffer = ''
94+
if self.from_end:
95+
# this only happens when we are ready to read
96+
self.file.seek(0, 2)
8197
self.file = f
8298
self.from_end = from_end
8399
self.delimiter = delimiter
84100

85101
self.poll_interval = poll_interval
86-
super(from_textfile, self).__init__(ensure_io_loop=True, **kwargs)
87-
self.stopped = True
88-
self.started = False
89-
if start:
90-
self.start()
102+
super().__init__(**kwargs)
91103

92-
def start(self):
93-
self.stopped = False
94-
self.started = False
95-
self.loop.add_callback(self.do_poll)
96-
97-
@gen.coroutine
98-
def do_poll(self):
99-
buffer = ''
100-
if self.from_end:
101-
# this only happens when we are ready to read
102-
self.file.seek(0, 2)
103-
while not self.stopped:
104-
self.started = True
105-
line = self.file.read()
106-
if line:
107-
buffer = buffer + line
108-
if self.delimiter in buffer:
109-
parts = buffer.split(self.delimiter)
110-
buffer = parts.pop(-1)
111-
for part in parts:
112-
yield self._emit(part + self.delimiter)
113-
else:
114-
yield gen.sleep(self.poll_interval)
104+
async def _run(self):
105+
line = self.file.read()
106+
if line:
107+
self.buffer = self.buffer + line
108+
if self.delimiter in self.buffer:
109+
parts = self.buffer.split(self.delimiter)
110+
self.buffer = parts.pop(-1)
111+
for part in parts:
112+
await self._emit(part + self.delimiter)
113+
else:
114+
await asyncio.sleep(self.poll_interval)
115115

116116

117117
@Stream.register_api(staticmethod)
@@ -133,7 +133,7 @@ class filenames(Source):
133133
>>> source = Stream.filenames('path/to/dir') # doctest: +SKIP
134134
>>> source = Stream.filenames('path/to/*.csv', poll_interval=0.500) # doctest: +SKIP
135135
"""
136-
def __init__(self, path, poll_interval=0.100, start=False, **kwargs):
136+
def __init__(self, path, poll_interval=0.100, **kwargs):
137137
if '*' not in path:
138138
if os.path.isdir(path):
139139
if not path.endswith(os.path.sep):
@@ -142,26 +142,15 @@ def __init__(self, path, poll_interval=0.100, start=False, **kwargs):
142142
self.path = path
143143
self.seen = set()
144144
self.poll_interval = poll_interval
145-
self.stopped = True
146-
super(filenames, self).__init__(ensure_io_loop=True)
147-
if start:
148-
self.start()
145+
super().__init__(**kwargs)
149146

150-
def start(self):
151-
self.stopped = False
152-
self.loop.add_callback(self.do_poll)
153-
154-
@gen.coroutine
155-
def do_poll(self):
156-
while True:
157-
filenames = set(glob(self.path))
158-
new = filenames - self.seen
159-
for fn in sorted(new):
160-
self.seen.add(fn)
161-
yield self._emit(fn)
162-
yield gen.sleep(self.poll_interval) # TODO: remove poll if delayed
163-
if self.stopped:
164-
break
147+
async def _run(self):
148+
filenames = set(glob(self.path))
149+
new = filenames - self.seen
150+
for fn in sorted(new):
151+
self.seen.add(fn)
152+
await self._emit(fn)
153+
await asyncio.sleep(self.poll_interval) # TODO: remove poll if delayed
165154

166155

167156
@Stream.register_api(staticmethod)
@@ -191,42 +180,31 @@ class from_tcp(Source):
191180
192181
>>> source = Source.from_tcp(4567) # doctest: +SKIP
193182
"""
194-
def __init__(self, port, delimiter=b'\n', start=False,
195-
server_kwargs=None):
196-
super(from_tcp, self).__init__(ensure_io_loop=True)
197-
self.stopped = True
183+
def __init__(self, port, delimiter=b'\n', server_kwargs=None, **kwargs):
198184
self.server_kwargs = server_kwargs or {}
199185
self.port = port
200186
self.server = None
201187
self.delimiter = delimiter
202-
if start: # pragma: no cover
203-
self.start()
188+
super().__init__(**kwargs)
204189

205-
@gen.coroutine
206-
def _start_server(self):
190+
def run(self):
207191
from tornado.tcpserver import TCPServer
208192
from tornado.iostream import StreamClosedError
209193

210194
class EmitServer(TCPServer):
211195
source = self
212196

213-
@gen.coroutine
214-
def handle_stream(self, stream, address):
215-
while True:
197+
async def handle_stream(self, stream, address):
198+
while not self.source.stopped:
216199
try:
217-
data = yield stream.read_until(self.source.delimiter)
218-
yield self.source._emit(data)
200+
data = await stream.read_until(self.source.delimiter)
201+
await self.source._emit(data)
219202
except StreamClosedError:
220203
break
221204

222205
self.server = EmitServer(**self.server_kwargs)
223206
self.server.listen(self.port)
224207

225-
def start(self):
226-
if self.stopped:
227-
self.loop.add_callback(self._start_server)
228-
self.stopped = False
229-
230208
def stop(self):
231209
if not self.stopped:
232210
self.server.stop()
@@ -260,26 +238,22 @@ class from_http_server(Source):
260238
261239
"""
262240

263-
def __init__(self, port, path='/.*', start=False, server_kwargs=None):
241+
def __init__(self, port, path='/.*', server_kwargs=None, **kwargs):
264242
self.port = port
265243
self.path = path
266244
self.server_kwargs = server_kwargs or {}
267-
super(from_http_server, self).__init__(ensure_io_loop=True)
268-
self.stopped = True
269245
self.server = None
270-
if start: # pragma: no cover
271-
self.start()
246+
super().__init__(**kwargs)
272247

273-
def _start_server(self):
248+
def run(self):
274249
from tornado.web import Application, RequestHandler
275250
from tornado.httpserver import HTTPServer
276251

277252
class Handler(RequestHandler):
278253
source = self
279254

280-
@gen.coroutine
281-
def post(self):
282-
yield self.source._emit(self.request.body)
255+
async def post(self):
256+
await asyncio.gather(*self.source._emit(self.request.body))
283257
self.write('OK')
284258

285259
application = Application([
@@ -288,12 +262,6 @@ def post(self):
288262
self.server = HTTPServer(application, **self.server_kwargs)
289263
self.server.listen(self.port)
290264

291-
def start(self):
292-
"""Start HTTP server and listen"""
293-
if self.stopped:
294-
self.loop.add_callback(self._start_server)
295-
self.stopped = False
296-
297265
def stop(self):
298266
"""Shutdown HTTP server"""
299267
if not self.stopped:
@@ -325,46 +293,36 @@ class from_process(Source):
325293
>>> source = Source.from_process(['ping', 'localhost']) # doctest: +SKIP
326294
"""
327295

328-
def __init__(self, cmd, open_kwargs=None, with_stderr=False, start=False):
296+
def __init__(self, cmd, open_kwargs=None, with_stderr=False, with_end=True,
297+
**kwargs):
329298
self.cmd = cmd
330299
self.open_kwargs = open_kwargs or {}
331300
self.with_stderr = with_stderr
332-
super(from_process, self).__init__(ensure_io_loop=True)
333-
self.stopped = True
301+
self.with_end = with_end
334302
self.process = None
335-
if start: # pragma: no cover
336-
self.start()
303+
super().__init__(**kwargs)
337304

338-
@gen.coroutine
339-
def _start_process(self):
340-
# should be done in asyncio (py3 only)? Apparently can handle Windows
341-
# with appropriate config.
342-
from tornado.process import Subprocess
343-
from tornado.iostream import StreamClosedError
305+
async def run(self):
344306
import subprocess
345-
stderr = subprocess.STDOUT if self.with_stderr else subprocess.PIPE
346-
process = Subprocess(self.cmd, stdout=Subprocess.STREAM,
347-
stderr=stderr, **self.open_kwargs)
307+
stderr = subprocess.STDOUT if self.with_stderr else None
308+
if isinstance(self.cmd, (list, tuple)):
309+
cmd, *args = self.cmd
310+
else:
311+
cmd, args = self.cmd, ()
312+
process = await asyncio.create_subprocess_exec(
313+
cmd, *args, stdout=subprocess.PIPE,
314+
stderr=stderr, **self.open_kwargs)
348315
while not self.stopped:
349316
try:
350-
out = yield process.stdout.read_until(b'\n')
351-
except StreamClosedError:
352-
# process exited
353-
break
354-
yield self._emit(out)
355-
yield process.stdout.close()
356-
process.proc.terminate()
357-
358-
def start(self):
359-
"""Start external process"""
360-
if self.stopped:
361-
self.loop.add_callback(self._start_process)
362-
self.stopped = False
363-
364-
def stop(self):
365-
"""Shutdown external process"""
366-
if not self.stopped:
367-
self.stopped = True
317+
out = await process.stdout.readuntil(b'\n')
318+
except asyncio.IncompleteReadError as err:
319+
if self.with_end:
320+
out = err.partial
321+
else:
322+
break
323+
await asyncio.gather(*self._emit(out))
324+
process.terminate()
325+
await process.wait()
368326

369327

370328
@Stream.register_api(staticmethod)
@@ -401,15 +359,12 @@ class from_kafka(Source):
401359
... 'group.id': 'streamz'}) # doctest: +SKIP
402360
403361
"""
404-
def __init__(self, topics, consumer_params, poll_interval=0.1, start=False, **kwargs):
362+
def __init__(self, topics, consumer_params, poll_interval=0.1, **kwargs):
405363
self.cpars = consumer_params
406364
self.consumer = None
407365
self.topics = topics
408366
self.poll_interval = poll_interval
409-
super(from_kafka, self).__init__(ensure_io_loop=True, **kwargs)
410-
self.stopped = True
411-
if start:
412-
self.start()
367+
super().__init__(**kwargs)
413368

414369
def do_poll(self):
415370
if self.consumer is not None:
@@ -470,10 +425,9 @@ def __init__(self, topic, consumer_params, poll_interval='1s',
470425
self.max_batch_size = max_batch_size
471426
self.keys = keys
472427
self.engine = engine
473-
self.stopped = True
474428
self.started = False
475429

476-
super(FromKafkaBatched, self).__init__(ensure_io_loop=True, **kwargs)
430+
super().__init__(**kwargs)
477431

478432
@gen.coroutine
479433
def poll_kafka(self):
@@ -753,17 +707,12 @@ class from_iterable(Source):
753707
"""
754708

755709
def __init__(self, iterable, **kwargs):
756-
super().__init__(ensure_io_loop=True, **kwargs)
757710
self._iterable = iterable
711+
super().__init__(**kwargs)
758712

759-
def start(self):
760-
self.stopped = False
761-
self.loop.add_callback(self._run)
762-
763-
@gen.coroutine
764-
def _run(self):
713+
async def run(self):
765714
for x in self._iterable:
766715
if self.stopped:
767716
break
768-
yield self._emit(x)
717+
await asyncio.gather(*self._emit(x))
769718
self.stopped = True

streamz/tests/test_sources.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,19 @@ def test_http():
8686
requests.post('http://localhost:%i/other' % port, data=b'data2')
8787

8888

89-
@flaky(max_runs=3, min_passes=1)
89+
#@flaky(max_runs=3, min_passes=1)
9090
@gen_test(timeout=60)
9191
def test_process():
92+
import sys
93+
import asyncio
9294
cmd = ["python", "-c", "for i in range(4): print(i)"]
9395
s = Source.from_process(cmd)
96+
if sys.platform != "win32":
97+
# don't know why - something with pytest and new processes
98+
policy = asyncio.get_event_loop_policy()
99+
watcher = asyncio.SafeChildWatcher()
100+
policy.set_child_watcher(watcher)
101+
watcher.attach_loop(s.loop.asyncio_loop)
94102
out = s.sink_to_list()
95103
s.start()
96104
yield await_for(lambda: out == [b'0\n', b'1\n', b'2\n', b'3\n'], timeout=5)

0 commit comments

Comments
 (0)