1414 vmax , vmin )
1515from devito .passes .clusters .cse import _cse
1616from devito .symbolics import (Uxmapper , estimate_cost , search , reuse_if_untouched ,
17- uxreplace , sympy_dtype )
17+ retrieve_functions , uxreplace , sympy_dtype )
1818from devito .tools import (Stamp , as_mapper , as_tuple , flatten , frozendict ,
1919 is_integer , generator , split , timed_pass )
2020from devito .types import (Eq , Symbol , Temp , TempArray , TempFunction ,
@@ -113,6 +113,7 @@ def __init__(self, sregistry, options, platform):
113113 self .opt_rotate = options ['cire-rotate' ]
114114 self .opt_ftemps = options ['cire-ftemps' ]
115115 self .opt_mingain = options ['cire-mingain' ]
116+ self .opt_minmem = options ['cire-minmem' ]
116117 self .opt_min_dtype = options ['scalar-min-type' ]
117118 self .opt_multisubdomain = True
118119
@@ -143,7 +144,8 @@ def _aliases_from_clusters(self, clusters, exclude, meta):
143144
144145 # Schedule -> [Clusters]_k
145146 processed , subs = lower_schedule (schedule , meta , self .sregistry ,
146- self .opt_ftemps , self .opt_min_dtype )
147+ self .opt_ftemps , self .opt_min_dtype ,
148+ self .opt_minmem )
147149
148150 # [Clusters]_k -> [Clusters]_k (optimization)
149151 if self .opt_multisubdomain :
@@ -831,11 +833,12 @@ def optimize_schedule_rotations(schedule, sregistry):
831833 return schedule .rebuild (* processed , rmapper = rmapper )
832834
833835
834- def lower_schedule (schedule , meta , sregistry , ftemps , min_dtype ):
836+ def lower_schedule (schedule , meta , sregistry , opt_ftemps , opt_min_dtype ,
837+ opt_minmem ):
835838 """
836839 Turn a Schedule into a sequence of Clusters.
837840 """
838- if ftemps :
841+ if opt_ftemps :
839842 make = TempFunction
840843 else :
841844 # Typical case -- the user does *not* "see" the CIRE-created temporaries
@@ -865,8 +868,26 @@ def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
865868 dimensions = [d .parent if d .is_AbstractSub else d
866869 for d in writeto .itdims ]
867870
868- # The halo must be set according to the size of `writeto`
869- halo = [(abs (i .lower ), abs (i .upper )) for i in writeto ]
871+ # The minimum halo required along each Dimension depends on `writeto`.
872+ # If the user asks to minimize memory via `opt_minmem`, we stick to it.
873+ # Otherwise, we extend the halo up to that of the Functions appearing
874+ # in `pivot`, to minimize support variables such as strides etc
875+ halo = {i .dim : Size (abs (i .lower ), abs (i .upper )) for i in writeto }
876+
877+ if opt_minmem :
878+ functions = []
879+ else :
880+ functions = retrieve_functions (pivot )
881+
882+ for f in functions :
883+ for d , h0 in list (halo .items ()):
884+ try :
885+ h1 = f ._size_halo [d ]
886+ except KeyError :
887+ continue
888+ halo [d ] = Size (max (h0 .left , h1 .left ), max (h0 .right , h1 .right ))
889+
890+ halo = tuple (halo .values ())
870891
871892 # The indices used to write into the Array
872893 indices = []
@@ -889,7 +910,7 @@ def lower_schedule(schedule, meta, sregistry, ftemps, min_dtype):
889910 # Degenerate case: scalar expression
890911 assert writeto .size == 0
891912
892- dtype = sympy_dtype (pivot , base = meta .dtype , smin = min_dtype )
913+ dtype = sympy_dtype (pivot , base = meta .dtype , smin = opt_min_dtype )
893914 obj = Temp (name = name , dtype = dtype )
894915 expression = Eq (obj , uxreplace (pivot , subs ))
895916
@@ -1037,6 +1058,9 @@ def pick_best(variants):
10371058# Utilities
10381059
10391060
1061+ Size = namedtuple ('Size' , 'left right' )
1062+
1063+
10401064class Group (tuple ):
10411065
10421066 """
0 commit comments