Skip to content

Commit a2f0eb2

Browse files
committed
compiler: Add TempArray.shift
1 parent 571833f commit a2f0eb2

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,8 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
913913
indices.append(i.dim - i.lower + s)
914914

915915
dtype = sympy_dtype(pivot, base=meta.dtype)
916-
obj = make(name=name, dimensions=dimensions, halo=halo, dtype=dtype)
916+
obj = make(name=name, dimensions=dimensions, halo=halo, dtype=dtype,
917+
shift=shift)
917918
expression = Eq(obj[indices], uxreplace(pivot, subs))
918919

919920
callback = lambda idx: obj[[i + s for i, s in zip(idx, shift)]]

devito/types/misc.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from devito.types import Array, CompositeObject, Indexed, Symbol, LocalObject
1212
from devito.types.basic import IndexedData
13-
from devito.tools import CustomDtype, Pickable, frozendict
13+
from devito.tools import CustomDtype, Pickable, as_tuple, frozendict
1414

1515
__all__ = ['Timer', 'Pointer', 'VolatileInt', 'FIndexed', 'Wildcard', 'Fence',
1616
'Global', 'Hyperplane', 'Indirection', 'Temp', 'TempArray', 'Jump',
@@ -235,12 +235,25 @@ class TempArray(Array):
235235

236236
is_autopaddable = True
237237

238+
__rkwargs__ = (Array.__rkwargs__ + ('shift',))
239+
240+
def __init_finalize__(self, *args, shift=None, **kwargs):
241+
super().__init_finalize__(*args, **kwargs)
242+
243+
# An integer for each Dimension representing the shift applied to the halo
244+
# for homogeneity reasons
245+
self._shift = as_tuple(shift)
246+
238247
def __padding_setup__(self, **kwargs):
239248
padding = kwargs.pop('padding', None)
240249
if padding is None:
241250
padding = self.__padding_setup_smart__(**kwargs)
242251
return super().__padding_setup__(padding=padding, **kwargs)
243252

253+
@property
254+
def shift(self):
255+
return self._shift
256+
244257

245258
class Fence:
246259

0 commit comments

Comments
 (0)