@@ -165,8 +165,10 @@ def __init__(self, upstream=None, upstreams=None, stream_name=None,
165165 self .downstreams = OrderedWeakrefSet ()
166166 if upstreams is not None :
167167 self .upstreams = list (upstreams )
168- else :
168+ elif upstream is not None :
169169 self .upstreams = [upstream ]
170+ else :
171+ self .upstreams = []
170172
171173 self ._set_asynchronous (asynchronous )
172174 self ._set_loop (loop )
@@ -236,10 +238,7 @@ def _inform_asynchronous(self, asynchronous):
236238 def _add_upstream (self , upstream ):
237239 """Add upstream to current upstreams, this method is overridden for
238240 classes which handle stream specific buffers/caches"""
239- if self .upstreams == [None ]:
240- self .upstreams [0 ] = upstream
241- else :
242- self .upstreams .append (upstream )
241+ self .upstreams .append (upstream )
243242
244243 def _add_downstream (self , downstream ):
245244 """Add downstream to current downstreams"""
@@ -252,10 +251,7 @@ def _remove_downstream(self, downstream):
252251 def _remove_upstream (self , upstream ):
253252 """Remove upstream from current upstreams, this method is overridden for
254253 classes which handle stream specific buffers/caches"""
255- if len (self .upstreams ) == 1 :
256- self .upstreams [0 ] = [None ]
257- else :
258- self .upstreams .remove (upstream )
254+ self .upstreams .remove (upstream )
259255
260256 @classmethod
261257 def register_api (cls , modifier = identity , attribute_name = None ):
@@ -527,8 +523,8 @@ def destroy(self, streams=None):
527523 if streams is None :
528524 streams = self .upstreams
529525 for upstream in list (streams ):
530- upstream .downstreams . remove (self )
531- self .upstreams . remove (upstream )
526+ upstream ._remove_downstream (self )
527+ self ._remove_upstream (upstream )
532528
533529 def scatter (self , ** kwargs ):
534530 from .dask import scatter
0 commit comments