Skip to content

Commit ff956bc

Browse files
committed
compiler: Add cire-minmem optoption
1 parent 59c357b commit ff956bc

6 files changed

Lines changed: 68 additions & 8 deletions

File tree

devito/core/cpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def _normalize_kwargs(cls, **kwargs):
6161
o['cire-maxpar'] = oo.pop('cire-maxpar', False)
6262
o['cire-ftemps'] = oo.pop('cire-ftemps', False)
6363
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN)
64+
o['cire-minmem'] = oo.pop('cire-minmem', cls.CIRE_MINMEM)
6465
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)
6566

6667
# Shared-memory parallelism

devito/core/gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def _normalize_kwargs(cls, **kwargs):
6868
o['cire-maxpar'] = oo.pop('cire-maxpar', True)
6969
o['cire-ftemps'] = oo.pop('cire-ftemps', False)
7070
o['cire-mingain'] = oo.pop('cire-mingain', cls.CIRE_MINGAIN)
71+
o['cire-minmem'] = oo.pop('cire-minmem', cls.CIRE_MINMEM)
7172
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)
7273

7374
# GPU parallelism

devito/core/operator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ class BasicOperator(Operator):
6969
intensity of the generated kernel.
7070
"""
7171

72+
CIRE_MINMEM = True
73+
"""
74+
Minimize memory consumption when allocating temporaries for CIRE-optimized
75+
expressions. This may come at the cost of slighly worse performance due to
76+
the potential need for extra registers to hold a greater number of support
77+
variables (e.g., strides).
78+
"""
79+
7280
SCALAR_MIN_TYPE = np.float16
7381
"""
7482
Minimum datatype for a scalar arising from a common sub-expression or CIRE temp.

devito/passes/clusters/aliases.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
vmax, vmin)
1515
from devito.passes.clusters.cse import _cse
1616
from devito.symbolics import (Uxmapper, estimate_cost, search, reuse_if_untouched,
17-
uxreplace, sympy_dtype)
17+
retrieve_functions, uxreplace, sympy_dtype)
1818
from devito.tools import (Stamp, as_mapper, as_tuple, flatten, frozendict,
1919
is_integer, generator, split, timed_pass)
2020
from devito.types import (Eq, Symbol, Temp, TempArray, TempFunction,
@@ -113,6 +113,7 @@ def __init__(self, sregistry, options, platform):
113113
self.opt_rotate = options['cire-rotate']
114114
self.opt_ftemps = options['cire-ftemps']
115115
self.opt_mingain = options['cire-mingain']
116+
self.opt_minmem = options['cire-minmem']
116117
self.opt_min_dtype = options['scalar-min-type']
117118
self.opt_multisubdomain = True
118119

@@ -143,7 +144,8 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
143144

144145
# Schedule -> [Clusters]_k
145146
processed, subs = lower_schedule(schedule, meta, self.sregistry,
146-
self.opt_ftemps, self.opt_min_dtype)
147+
self.opt_ftemps, self.opt_min_dtype,
148+
self.opt_minmem)
147149

148150
# [Clusters]_k -> [Clusters]_k (optimization)
149151
if self.opt_multisubdomain:
@@ -831,11 +833,12 @@ def optimize_schedule_rotations(schedule, sregistry):
831833
return schedule.rebuild(*processed, rmapper=rmapper)
832834

833835

834-
def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
836+
def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
837+
opt_minmem):
835838
"""
836839
Turn a Schedule into a sequence of Clusters.
837840
"""
838-
if ftemps:
841+
if opt_ftemps:
839842
make = TempFunction
840843
else:
841844
# Typical case -- the user does *not* "see" the CIRE-created temporaries
@@ -865,8 +868,26 @@ def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
865868
dimensions = [d.parent if d.is_AbstractSub else d
866869
for d in writeto.itdims]
867870

868-
# The halo must be set according to the size of `writeto`
869-
halo = [(abs(i.lower), abs(i.upper)) for i in writeto]
871+
# The minimum halo required along each Dimension depends on `writeto`.
872+
# The user might suggest to go more relaxed about this via `opt_minmem`,
873+
# in which case we extend the halo based on the surrounding
874+
# Functions to minimize support variables such as strides etc
875+
halo = {i.dim: Size(abs(i.lower), abs(i.upper)) for i in writeto}
876+
877+
if opt_minmem:
878+
functions = []
879+
else:
880+
functions = retrieve_functions(pivot)
881+
882+
for f in functions:
883+
for d, h0 in list(halo.items()):
884+
try:
885+
h1 = f._size_halo[d]
886+
except KeyError:
887+
continue
888+
halo[d] = Size(max(h0.left, h1.left), max(h0.right, h1.right))
889+
890+
halo = tuple(halo.values())
870891

871892
# The indices used to write into the Array
872893
indices = []
@@ -889,7 +910,7 @@ def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
889910
# Degenerate case: scalar expression
890911
assert writeto.size == 0
891912

892-
dtype = sympy_dtype(pivot, base=meta.dtype, smin=min_dtype)
913+
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
893914
obj = Temp(name=name, dtype=dtype)
894915
expression = Eq(obj, uxreplace(pivot, subs))
895916

@@ -1037,6 +1058,9 @@ def pick_best(variants):
10371058
# Utilities
10381059

10391060

1061+
Size = namedtuple('Size', 'left right')
1062+
1063+
10401064
class Group(tuple):
10411065

10421066
"""

devito/passes/iet/linearization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def key1(f, d):
6767
6868
* False if not statically linearizable, that is not linearizable via
6969
constant symbolic sizes and strides;
70-
* A 3-tuple `(Dimension, halo size, grid)` otherwise.
70+
* A 3-tuple `(Dimension, halo size, pad dtype)` otherwise.
7171
"""
7272
if f.is_regular:
7373
# For paddable objects the following holds:

tests/test_linearize.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,3 +659,29 @@ def test_int64_array(order):
659659
else:
660660
long = '(long)'
661661
assert f'({2*order} + {long}y_size)*({2*order} + {long}x_size))' in str(op)
662+
663+
664+
def test_cire_n_strides():
665+
grid = Grid(shape=(4, 4, 4))
666+
667+
u = TimeFunction(name='u', grid=grid, space_order=8)
668+
u1 = TimeFunction(name='u', grid=grid, space_order=8)
669+
670+
eqn = Eq(u.forward, u.dy.dx + u.dy.dy + u.dy.dz + 1.)
671+
672+
op0 = Operator(eqn, opt=('advanced', {'linearize': False, 'cire-mingain': 0}))
673+
op1 = Operator(eqn, opt=('advanced', {'linearize': True, 'cire-mingain': 0}))
674+
op2 = Operator(eqn, opt=('advanced', {'linearize': True,
675+
'cire-mingain': 0,
676+
'cire-minmem': False}))
677+
678+
# Check generated code
679+
assert 'uL0' in str(op1)
680+
assert len(op1.body.strides) == 11
681+
assert 'uL0' in str(op2)
682+
assert len(op2.body.strides) == 9 # Fewer size/stride vars thx to cire-minmem
683+
684+
op0.apply(time_M=10)
685+
op2.apply(time_M=10, u=u1)
686+
687+
assert np.all(u.data == u1.data)

0 commit comments

Comments
 (0)