|
1 | 1 | import asyncio |
| 2 | +import concurrent.futures |
2 | 3 | from collections import deque, defaultdict |
3 | 4 | from datetime import timedelta |
4 | 5 | from itertools import chain |
5 | 6 | import functools |
6 | 7 | import logging |
7 | 8 | import threading |
8 | 9 | from time import time |
9 | | -from typing import Any, Callable, Hashable, Union |
| 10 | +from typing import Any, Callable, Coroutine, Hashable, Tuple, Union, overload |
10 | 11 | import weakref |
11 | 12 |
|
12 | 13 | import toolz |
@@ -756,28 +757,54 @@ def __init__(self, upstream, func, *args, parallelism=1, stop_on_exception=False |
756 | 757 | stream_name = kwargs.pop('stream_name', None) |
757 | 758 | self.kwargs = kwargs |
758 | 759 | self.args = args |
759 | | - self.running = True |
760 | 760 | self.stop_on_exception = stop_on_exception |
761 | 761 | self.work_queue = asyncio.Queue(maxsize=parallelism) |
762 | 762 |
|
763 | 763 | Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True) |
764 | | - self.work_task = self._create_task(self.work_callback()) |
| 764 | + self.work_task = None |
| 765 | + |
| 766 | + def _create_work_task(self) -> Tuple[asyncio.Event, asyncio.Task[None]]: |
| 767 | + stop_work = asyncio.Event() |
| 768 | + work_task = self._create_task(self.work_callback(stop_work)) |
| 769 | + return stop_work, work_task |
| 770 | + |
| 771 | + def start(self): |
| 772 | + if self.work_task: |
| 773 | + stop_work, _ = self.work_task |
| 774 | + stop_work.set() |
| 775 | + self.work_task = self._create_work_task() |
| 776 | + super().start() |
765 | 777 |
|
766 | 778 | def stop(self): |
767 | | - if self.running: |
768 | | - self.running = False |
769 | | - super().stop() |
| 779 | + stop_work, _ = self.work_task |
| 780 | + stop_work.set() |
| 781 | + self.work_task = None |
| 782 | + super().stop() |
770 | 783 |
|
771 | 784 | def update(self, x, who=None, metadata=None): |
| 785 | + if not self.work_task: |
| 786 | + self.work_task = self._create_work_task() |
772 | 787 | return self._create_task(self._insert_job(x, metadata)) |
773 | 788 |
|
| 789 | + @overload |
| 790 | + def _create_task(self, coro: asyncio.Future) -> asyncio.Future: |
| 791 | + ... |
| 792 | + |
| 793 | + @overload |
| 794 | + def _create_task(self, coro: concurrent.futures.Future) -> concurrent.futures.Future: |
| 795 | + ... |
| 796 | + |
| 797 | + @overload |
| 798 | + def _create_task(self, coro: Coroutine) -> asyncio.Task: |
| 799 | + ... |
| 800 | + |
774 | 801 | def _create_task(self, coro): |
775 | 802 | if gen.is_future(coro): |
776 | 803 | return coro |
777 | 804 | return self.loop.asyncio_loop.create_task(coro) |
778 | 805 |
|
779 | | - async def work_callback(self): |
780 | | - while self.running: |
| 806 | + async def work_callback(self, stop_work: asyncio.Event): |
| 807 | + while not stop_work.is_set(): |
781 | 808 | task, metadata = await self.work_queue.get() |
782 | 809 | self.work_queue.task_done() |
783 | 810 | try: |
|
0 commit comments