Skip to content

Commit 12fc9c9

Browse files
authored
Merge pull request #2828 from devitocodes/stagg-time
api: fix evaluation with different time dims
2 parents ca50531 + fcea800 commit 12fc9c9

4 files changed

Lines changed: 25 additions & 3 deletions

File tree

devito/symbolics/manipulation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from devito.symbolics.unevaluation import Mul as UnevalMul
1717
from devito.symbolics.unevaluation import Pow as UnevalPow
1818
from devito.symbolics.unevaluation import UnevaluableMixin
19-
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
19+
from devito.tools import (
20+
EnrichedTuple, as_list, as_tuple, flatten, split, transitive_closure
21+
)
2022
from devito.types.array import ComponentAccess
2123
from devito.types.basic import Basic, Indexed
2224
from devito.types.equation import Eq
@@ -130,6 +132,12 @@ def _(iterable, rule):
130132
return iterable.__class__(ret), changed
131133

132134

135+
@_uxreplace_dispatch.register(EnrichedTuple)
136+
def _(iterable, rule):
137+
retval, changed = _uxreplace_dispatch(tuple(iterable), rule)
138+
return iterable.__class__(*retval, getters=iterable.getters), changed
139+
140+
133141
@_uxreplace_dispatch.register(dict)
134142
def _(mapper, rule):
135143
ret = {}

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def origin(self):
971971
@property
972972
def dimensions(self):
973973
"""Tuple of Dimensions representing the object indices."""
974-
return self._dimensions
974+
return DimensionTuple(*self._dimensions, getters=self._dimensions)
975975

976976
@cached_property
977977
def space_dimensions(self):

devito/types/dense.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,8 @@ def _eval_at(self, func):
11281128
for d in self.dimensions:
11291129
try:
11301130
if self.indices_ref[d] is not func.indices_ref[d]:
1131-
mapper[self.indices_ref[d]] = func.indices_ref[d]
1131+
f_idx = func.indices_ref[d]._subs(func.dimensions[d], d)
1132+
mapper[self.indices_ref[d]] = f_idx
11321133
except KeyError:
11331134
pass
11341135

tests/test_staggered_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,16 @@ def test_staggered_rebuild(stagg):
187187
assert f2.indices[nd] == nd + nd.spacing / 2
188188
else:
189189
assert f2.indices[nd] == nd
190+
191+
192+
def test_eval_at_different_dim():
193+
grid = Grid(shape=(31, 17, 25))
194+
nt = 5
195+
x, _, _ = grid.dimensions
196+
197+
v = TimeFunction(name="v", grid=grid, staggered=x)
198+
tau = TimeFunction(name="tau", grid=grid, save=nt)
199+
200+
eq = Eq(tau.forward, v).evaluate
201+
202+
assert grid.time_dim not in eq.rhs.free_symbols

0 commit comments

Comments
 (0)