@@ -111,6 +111,7 @@ class CireTransformer:
111111 def __init__ (self , sregistry , options , platform ):
112112 self .sregistry = sregistry
113113 self .platform = platform
114+
114115 self .opt_minstorage = options ['min-storage' ]
115116 self .opt_rotate = options ['cire-rotate' ]
116117 self .opt_ftemps = options ['cire-ftemps' ]
@@ -125,7 +126,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
125126 for mapper in self ._generate (cgroup , exclude ):
126127 # Clusters -> AliasList
127128 found = collect (mapper .extracted , meta .ispace , self .opt_minstorage )
128- exprs , aliases = choose (found , cgroup , mapper , self . opt_mingain )
129+ exprs , aliases = self . _choose (found , cgroup , mapper )
129130
130131 # AliasList -> Schedule
131132 schedule = lower_aliases (aliases , meta , self .opt_maxpar )
@@ -189,6 +190,52 @@ def _lookup_key(self, c):
189190 """
190191 raise NotImplementedError
191192
def _choose(self, aliases, cgroup, mapper):
    """
    Select, among the detected `aliases`, those worth optimizing, and
    inject them into the expressions of `cgroup`.

    A cost model driven by the alias score and by the working set of the
    expressions rules out the aliases with a bad memory/flops trade-off.

    Returns a 2-tuple with the (possibly rewritten) expressions and the
    AliasList of retained aliases.
    """
    exprs = cgroup.exprs

    candidates = AliasList(aliases)
    if not candidates:
        return exprs, candidates

    # Score thresholds:
    #   `score < lower`           => discarded
    #   `score > upper`           => optimized
    #   `lower <= score <= upper` => maybe optimized, depends on heuristics
    lower = self.opt_mingain
    upper = self.opt_mingain * 3

    def restrict(subs):
        # Keep only the substitutions referencing a still-candidate alias
        return {k: v for k, v in subs.items()
                if v.free_symbols & set(candidates.aliaseds)}

    # Drop the aliases whose score is too low
    candidates.filter(lambda a: a.score >= lower)

    # Project the candidate aliases into `exprs` to derive the final
    # working set
    subs = restrict(mapper)
    owset = wset([uxreplace(e, subs) for e in exprs])

    def worthwhile(a):
        # High-score aliases are always retained
        if a.score > upper:
            return True
        if not lower <= a.score <= upper:
            return False
        # Borderline score: retain only if the pivot's working set is not
        # entirely contained in the overall working set
        return max(len(wset(a.pivot)), 1) > len(wset(a.pivot) & owset)

    # Drop the aliases with a weak flop-reduction / working-set tradeoff
    candidates.filter(worthwhile)
    if not candidates:
        return exprs, candidates

    # Substitute the chosen aliasing sub-expressions
    subs = restrict(subs)
    return [uxreplace(e, subs) for e in exprs], candidates
238+
192239 def _select (self , variants ):
193240 """
194241 Select the best variant out of a set of `variants`, weighing flops and
@@ -611,49 +658,6 @@ def collect(extracted, ispace, minstorage):
611658 return aliases
612659
613660
def choose(aliases, cgroup, mapper, mingain):
    """
    Apply a cost model to the detected `aliases` to discard those with a bad
    memory/flops trade-off, then inject the survivors into the expressions
    of `cgroup`.

    `mingain` is the minimum score an alias must have to be considered at
    all; aliases scoring above `3*mingain` are always retained, while those
    in between are retained based on a working-set heuristic.

    Returns a 2-tuple with the (possibly rewritten) expressions and the
    AliasList of retained aliases.
    """
    exprs = cgroup.exprs

    aliases = AliasList(aliases)
    if not aliases:
        return exprs, aliases

    # `score < lo`        => discarded
    # `score > hi`        => optimized
    # `lo <= score <= hi` => maybe discarded, maybe optimized; depends on
    #                        heuristics
    lo = mingain
    hi = mingain * 3

    def scores_enough(a):
        return a.score >= lo

    # Filter off the aliases with low score
    aliases.filter(scores_enough)

    def project(subs):
        # Narrow `subs` to the entries using a surviving alias, then apply
        # them to `exprs`
        subs = {k: v for k, v in subs.items()
                if v.free_symbols & set(aliases.aliaseds)}
        return subs, [uxreplace(e, subs) for e in exprs]

    # Project the candidate aliases into `exprs` to derive the final
    # working set
    mapper, templated = project(mapper)
    owset = wset(templated)

    def pays_off(a):
        if a.score > hi:
            return True
        in_band = lo <= a.score <= hi
        return in_band and (max(len(wset(a.pivot)), 1) >
                            len(wset(a.pivot) & owset))

    # Filter off the aliases with a weak flop-reduction / working-set
    # tradeoff
    aliases.filter(pays_off)

    if not aliases:
        return exprs, aliases

    # Substitute the chosen aliasing sub-expressions
    mapper, exprs = project(mapper)

    return exprs, aliases
655-
656-
657661def lower_aliases (aliases , meta , maxpar ):
658662 """
659663 Create a Schedule from an AliasList.
0 commit comments