Skip to content

Commit 8e36307

Browse files
committed
api: fix staggering handling at rebuild
1 parent df444ea commit 8e36307

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

devito/types/dense.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,9 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
11431143
* 0 to non-staggered dimensions;
11441144
* 1 to staggered dimensions.
11451145
"""
1146+
if isinstance(staggered, Staggering):
1147+
staggered = staggered._ref
1148+
11461149
if not staggered:
11471150
processed = ()
11481151
elif staggered is CELL:

devito/types/tensor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@
2020
__all__ = ['TensorFunction', 'TensorTimeFunction', 'VectorFunction', 'VectorTimeFunction']
2121

2222

23+
def staggering(stagg, i, j, d, dims):
24+
if stagg is None:
25+
# No input
26+
return NODE if i == j else (d, dims[j])
27+
elif isinstance(stagg, (tuple, list)):
28+
# User input as list or tuple
29+
return stagg[i][j]
30+
elif isinstance(stagg, AbstractTensor):
31+
# From rebuild/tensor property. Indexed as a sympy Matrix
32+
return stagg[i, j]
33+
34+
2335
class TensorFunction(AbstractTensor):
2436
"""
2537
Tensor valued Function represented as a Matrix.
@@ -128,8 +140,7 @@ def __subfunc_setup__(cls, *args, **kwargs):
128140
start = i if (symm or diag) else 0
129141
stop = i + 1 if diag else len(dims)
130142
for j in range(start, stop):
131-
staggj = (stagg[i][j] if stagg is not None
132-
else (NODE if i == j else (d, dims[j])))
143+
staggj = staggering(stagg, i, j, d, dims)
133144
sub_kwargs = cls._component_kwargs((i, j), **kwargs)
134145
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
135146
'staggered': staggj})

devito/types/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ class Staggering(DimensionTuple):
5757
def on_node(self):
5858
return not self or all(s == 0 for s in self)
5959

60+
@property
61+
def _ref(self):
62+
if self.on_node:
63+
return NODE
64+
else:
65+
return tuple(d for d, s in zip(self.getters, self, strict=True) if s == 1)
66+
6067

6168
class IgnoreDimSort(tuple):
6269
"""A tuple subclass used to wrap the implicit_dims to indicate

0 commit comments

Comments
 (0)