Skip to content

Commit 1c65e48

Browse files
Merge pull request #2839 from devitocodes/revisit-cire-again
compiler: Turn aliases' choose() into an instance method
2 parents e40fe22 + 6606bc8 commit 1c65e48

1 file changed

Lines changed: 48 additions & 44 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class CireTransformer:
111111
def __init__(self, sregistry, options, platform):
112112
self.sregistry = sregistry
113113
self.platform = platform
114+
114115
self.opt_minstorage = options['min-storage']
115116
self.opt_rotate = options['cire-rotate']
116117
self.opt_ftemps = options['cire-ftemps']
@@ -125,7 +126,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
125126
for mapper in self._generate(cgroup, exclude):
126127
# Clusters -> AliasList
127128
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
128-
exprs, aliases = choose(found, cgroup, mapper, self.opt_mingain)
129+
exprs, aliases = self._choose(found, cgroup, mapper)
129130

130131
# AliasList -> Schedule
131132
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
@@ -189,6 +190,52 @@ def _lookup_key(self, c):
189190
"""
190191
raise NotImplementedError
191192

193+
def _choose(self, aliases, cgroup, mapper):
194+
"""
195+
Analyze the detected aliases and, after applying a cost model to rule
196+
out the aliases with a bad memory/flops trade-off, inject them into the
197+
original expressions.
198+
"""
199+
exprs = cgroup.exprs
200+
201+
aliases = AliasList(aliases)
202+
if not aliases:
203+
return exprs, aliases
204+
205+
# `score < m` => discarded
206+
# `score > M` => optimized
207+
# `m <= score <= M` => maybe optimized, depends on heuristics
208+
m = self.opt_mingain
209+
M = self.opt_mingain*3
210+
211+
# Filter off the aliases with low score
212+
key = lambda a: a.score >= m
213+
aliases.filter(key)
214+
215+
# Project the candidate aliases into `exprs` to derive the final
216+
# working set
217+
mapper = {k: v for k, v in mapper.items()
218+
if v.free_symbols & set(aliases.aliaseds)}
219+
templated = [uxreplace(e, mapper) for e in exprs]
220+
owset = wset(templated)
221+
222+
# Filter off the aliases with a weak flop-reduction / working-set tradeoff
223+
key = lambda a: \
224+
a.score > M or \
225+
m <= a.score <= M and (max(len(wset(a.pivot)), 1) >
226+
len(wset(a.pivot) & owset))
227+
aliases.filter(key)
228+
229+
if not aliases:
230+
return exprs, aliases
231+
232+
# Substitute the chosen aliasing sub-expressions
233+
mapper = {k: v for k, v in mapper.items()
234+
if v.free_symbols & set(aliases.aliaseds)}
235+
exprs = [uxreplace(e, mapper) for e in exprs]
236+
237+
return exprs, aliases
238+
192239
def _select(self, variants):
193240
"""
194241
Select the best variant out of a set of `variants`, weighing flops and
@@ -611,49 +658,6 @@ def collect(extracted, ispace, minstorage):
611658
return aliases
612659

613660

614-
def choose(aliases, cgroup, mapper, mingain):
615-
"""
616-
Analyze the detected aliases and, after applying a cost model to rule out
617-
the aliases with a bad memory/flops trade-off, inject them into the original
618-
expressions.
619-
"""
620-
exprs = cgroup.exprs
621-
622-
aliases = AliasList(aliases)
623-
if not aliases:
624-
return exprs, aliases
625-
626-
# `score < m` => discarded
627-
# `score > M` => optimized
628-
# `m <= score <= M` => maybe discarded, maybe optimized; depends on heuristics
629-
m = mingain
630-
M = mingain*3
631-
632-
# Filter off the aliases with low score
633-
key = lambda a: a.score >= m
634-
aliases.filter(key)
635-
636-
# Project the candidate aliases into `exprs` to derive the final working set
637-
mapper = {k: v for k, v in mapper.items() if v.free_symbols & set(aliases.aliaseds)}
638-
templated = [uxreplace(e, mapper) for e in exprs]
639-
owset = wset(templated)
640-
641-
# Filter off the aliases with a weak flop-reduction / working-set tradeoff
642-
key = lambda a: \
643-
a.score > M or \
644-
m <= a.score <= M and max(len(wset(a.pivot)), 1) > len(wset(a.pivot) & owset)
645-
aliases.filter(key)
646-
647-
if not aliases:
648-
return exprs, aliases
649-
650-
# Substitute the chosen aliasing sub-expressions
651-
mapper = {k: v for k, v in mapper.items() if v.free_symbols & set(aliases.aliaseds)}
652-
exprs = [uxreplace(e, mapper) for e in exprs]
653-
654-
return exprs, aliases
655-
656-
657661
def lower_aliases(aliases, meta, maxpar):
658662
"""
659663
Create a Schedule from an AliasList.

0 commit comments

Comments
 (0)