@@ -121,6 +121,13 @@ class Schedule(Queue):
121121 Dimension in both Clusters.
122122 """
123123
124+ FISSION_THRESHOLD = 2
125+ """
126+ The maximum number of iteration Dimensions such that we consider fissioning
127+ a sequence of Clusters to increase parallelism. IOW, if there are more than
128+ this number of iteration Dimensions, we do not even try to fission.
129+ """
130+
124131 @timed_pass (name = 'schedule' )
125132 def process (self , clusters ):
126133 return self ._process_fatd (clusters , 1 )
@@ -134,7 +141,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
134141
135142 # Take the innermost Dimension -- no other Clusters other than those in
136143 # `clusters` are supposed to share it
137- candidates = prefix [- 1 ].dim ._defines
144+ dim = prefix [- 1 ].dim
145+ candidates = dim ._defines
138146
139147 scope = Scope (flatten (c .exprs for c in clusters ))
140148
@@ -157,7 +165,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
157165 # Schedule Clusters over different IterationSpaces if this increases
158166 # parallelism
159167 for i in range (1 , len (clusters )):
160- if self ._break_for_parallelism (scope , candidates , i ):
168+ if self ._break_for_parallelism (scope , dim , i ):
161169 return self .callback (clusters [:i ], prefix , clusters [i :] + backlog ,
162170 candidates | known_break )
163171
@@ -189,7 +197,19 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
189197
190198 return processed + self .callback (backlog , prefix )
191199
192- def _break_for_parallelism (self , scope , candidates , timestamp ):
200+ def _break_for_parallelism (self , scope , dim , timestamp ):
201+ candidates = dim ._defines
202+
203+ # Do not fission for data locality reasons if there's enough potential
204+ # parallelism in the inner Dimensions
205+ try :
206+ ispace , = {e .ispace for e in scope .exprs [:timestamp ]}
207+ _ , ispace1 = ispace .split (dim )
208+ if len (ispace1 .itdims ) > self .FISSION_THRESHOLD :
209+ return False
210+ except ValueError :
211+ pass
212+
193213 # `test` will be True if there's at least one data-dependence that would
194214 # break parallelism
195215 test = False
0 commit comments