Skip to content

Commit 96463b5

Browse files
committed
Make map_async restartable
1 parent c357729 commit 96463b5

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

streamz/core.py

Lines changed: 35 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
import asyncio
2+
import concurrent.futures
23
from collections import deque, defaultdict
34
from datetime import timedelta
45
from itertools import chain
56
import functools
67
import logging
78
import threading
89
from time import time
9-
from typing import Any, Callable, Hashable, Union
10+
from typing import Any, Callable, Coroutine, Hashable, Tuple, Union, overload
1011
import weakref
1112

1213
import toolz
@@ -756,28 +757,54 @@ def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False
756757
stream_name = kwargs.pop('stream_name', None)
757758
self.kwargs = kwargs
758759
self.args = args
759-
self.running = True
760760
self.stop_on_exception = stop_on_exception
761761
self.work_queue = asyncio.Queue(maxsize=parallelism)
762762

763763
Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True)
764-
self.work_task = self._create_task(self.work_callback())
764+
self.work_task = None
765+
766+
def _create_work_task(self) -> Tuple[asyncio.Event, asyncio.Task[None]]:
767+
stop_work = asyncio.Event()
768+
work_task = self._create_task(self.work_callback(stop_work))
769+
return stop_work, work_task
770+
771+
def start(self):
    """Start (or restart) the node with a fresh worker.

    If a worker record is already stored in ``self.work_task``, its stop
    event is set first. A new ``(stop_event, task)`` pair is then created
    via ``self._create_work_task()`` and stored, and the parent class's
    ``start()`` is invoked.
    """
    previous = self.work_task
    if previous:
        # Tell the outgoing worker loop to finish before replacing it.
        retiring_event, _retiring_task = previous
        retiring_event.set()
    self.work_task = self._create_work_task()
    super().start()
765777

766778
def stop(self):
    """Signal the current worker (if any) to stop, then stop the stream.

    ``self.work_task`` is reset to ``None``. When it held a
    ``(stop_event, task)`` pair, the stop event is set so the worker loop
    can exit. The parent class's ``stop()`` is always invoked afterwards.

    Calling this when no worker exists (before the first ``start()`` /
    ``update()``, or a second time in a row) is safe: the unpacking is
    skipped instead of raising ``TypeError`` on ``None``.
    """
    work_task, self.work_task = self.work_task, None
    if work_task is not None:
        stop_work, _ = work_task
        stop_work.set()
    super().stop()
770783

771784
def update(self, x, who=None, metadata=None):
    """Queue ``x`` for processing, creating a worker first if none exists.

    Parameters
    ----------
    x:
        The element handed to ``self._insert_job``.
    who:
        Accepted for interface compatibility; not used in this method.
    metadata:
        Passed through to ``self._insert_job`` alongside ``x``.

    Returns
    -------
    The result of ``self._create_task`` applied to the insert-job call.
    """
    worker_missing = not self.work_task
    if worker_missing:
        # Lazily create the worker so emitting works without start().
        self.work_task = self._create_work_task()
    pending_insert = self._insert_job(x, metadata)
    return self._create_task(pending_insert)
773788

789+
# Typing-only signatures: futures are handed back with their own type,
# while a coroutine yields an asyncio.Task.
@overload
def _create_task(self, coro: asyncio.Future) -> asyncio.Future:
    ...

@overload
def _create_task(self, coro: concurrent.futures.Future) -> concurrent.futures.Future:
    ...

@overload
def _create_task(self, coro: Coroutine) -> asyncio.Task:
    ...

def _create_task(self, coro):
    """Return ``coro`` unchanged if it is a future, else schedule it.

    Anything ``gen.is_future`` does not recognise as a future is passed to
    ``self.loop.asyncio_loop.create_task`` and the resulting task is
    returned.
    """
    if not gen.is_future(coro):
        return self.loop.asyncio_loop.create_task(coro)
    return coro
778805

779-
async def work_callback(self):
780-
while self.running:
806+
async def work_callback(self, stop_work: asyncio.Event):
807+
while not stop_work.is_set():
781808
task, metadata = await self.work_queue.get()
782809
self.work_queue.task_done()
783810
try:

streamz/tests/test_core.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,28 @@ def fail_func():
151151
assert (time() - start) == pytest.approx(0.1, abs=4e-3)
152152

153153

154+
@pytest.mark.asyncio
async def test_map_async_restart():
    """Check that a map_async node can be started again after it stopped.

    The mapped coroutine raises on 2 and on anything above 4. With
    ``stop_on_exception=True`` the test expects ``[0, 1]`` and a cleared
    ``work_task`` after the first run, and ``[0, 1, 3, 4]`` after the
    source is started a second time.
    """
    async def flake_out(value):
        if value == 2:
            message = "I fail on 2."
        elif value > 4:
            message = "I fail on > 4."
        else:
            return value
        raise RuntimeError(message)

    def worker_cleared():
        return not mapped.work_task

    numbers = Stream.from_iterable(itertools.count())
    mapped = numbers.map_async(flake_out, stop_on_exception=True)
    results = mapped.sink_to_list()

    # First run: stops at the failure on 2.
    numbers.start()
    await await_for(lambda: results == [0, 1], 1)
    await await_for(worker_cleared, 1)

    # Second run: picks up at 3 and stops again at the failure on 5.
    numbers.start()
    await await_for(lambda: results == [0, 1, 3, 4], 1)
174+
175+
154176
@pytest.mark.asyncio
155177
async def test_map_async():
156178
@gen.coroutine

0 commit comments

Comments
 (0)