Skip to content

Commit c581d00

Browse files
authored
Merge pull request #2854 from devitocodes/tens-stagg-fix
Tens stagg fix
2 parents df444ea + a60b12a commit c581d00

File tree

6 files changed

+62
-12
lines changed

6 files changed

+62
-12
lines changed

devito/data/allocators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def alloc(self, shape, dtype, padding=0):
112112
raise RuntimeError(f"Unable to allocate {size} elements in memory")
113113

114114
# Compute the pointer to the user data
115-
padleft_bytes = padleft * ctypes.sizeof(ctype)
115+
padleft_bytes = int(padleft * ctypes.sizeof(ctype))
116116
c_pointer = ctypes.c_void_p(padleft_pointer.value + padleft_bytes)
117117

118118
# Cast to 1D array of the specified `datasize`

devito/types/basic.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import inspect
3+
import warnings
34
from contextlib import suppress
45
from ctypes import POINTER, Structure, _Pointer, c_char, c_char_p
56
from functools import cached_property, reduce
@@ -9,6 +10,7 @@
910
import sympy
1011
from sympy.core.assumptions import _assume_rules
1112
from sympy.core.decorators import call_highest_priority
13+
from sympy.utilities.exceptions import SymPyDeprecationWarning
1214

1315
from devito.data import default_allocator
1416
from devito.parameters import configuration
@@ -1533,10 +1535,20 @@ def _sympify(self, arg):
15331535
# This is used internally by sympy to process arguments at rebuilt. And since
15341536
# some of our properties are non-sympyfiable we need to have a fallback
15351537
try:
1536-
return super()._sympify(arg)
1537-
except sympy.SympifyError:
1538+
# Pure sympy object
1539+
return arg._sympy_()
1540+
except AttributeError:
15381541
return arg
15391542

1543+
@classmethod
1544+
def _eval_from_dok(cls, rows, cols, dok):
1545+
with warnings.catch_warnings():
1546+
warnings.filterwarnings(
1547+
"ignore",
1548+
category=SymPyDeprecationWarning
1549+
)
1550+
return super()._eval_from_dok(rows, cols, dok)
1551+
15401552
@property
15411553
def grid(self):
15421554
"""

devito/types/dense.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,9 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
11431143
* 0 to non-staggered dimensions;
11441144
* 1 to staggered dimensions.
11451145
"""
1146+
if isinstance(staggered, Staggering):
1147+
staggered = staggered._ref
1148+
11461149
if not staggered:
11471150
processed = ()
11481151
elif staggered is CELL:
@@ -1154,6 +1157,8 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
11541157
assert len(staggered) == len(dimensions)
11551158
processed = staggered
11561159
else:
1160+
# Staggering is not NODE or CELL or None
1161+
# therefore it's a tuple of dimensions
11571162
processed = []
11581163
for d in dimensions:
11591164
if d in as_tuple(staggered):

devito/types/tensor.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@
2020
__all__ = ['TensorFunction', 'TensorTimeFunction', 'VectorFunction', 'VectorTimeFunction']
2121

2222

23+
def staggering(stagg, i, j, d, dims):
24+
if stagg is None:
25+
# No input
26+
return NODE if i == j else (d, dims[j])
27+
elif isinstance(stagg, (tuple, list)):
28+
# User input as list or tuple
29+
return stagg[i][j]
30+
elif isinstance(stagg, AbstractTensor):
31+
# From rebuild/tensor property. Indexed as a sympy Matrix
32+
return stagg[i, j]
33+
34+
2335
class TensorFunction(AbstractTensor):
2436
"""
2537
Tensor valued Function represented as a Matrix.
@@ -128,8 +140,7 @@ def __subfunc_setup__(cls, *args, **kwargs):
128140
start = i if (symm or diag) else 0
129141
stop = i + 1 if diag else len(dims)
130142
for j in range(start, stop):
131-
staggj = (stagg[i][j] if stagg is not None
132-
else (NODE if i == j else (d, dims[j])))
143+
staggj = staggering(stagg, i, j, d, dims)
133144
sub_kwargs = cls._component_kwargs((i, j), **kwargs)
134145
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
135146
'staggered': staggj})

devito/types/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ class Staggering(DimensionTuple):
5757
def on_node(self):
5858
return not self or all(s == 0 for s in self)
5959

60+
@property
61+
def _ref(self):
62+
if not self:
63+
return None
64+
elif self.on_node:
65+
return NODE
66+
else:
67+
return tuple(d for d, s in zip(self.getters, self, strict=True) if s == 1)
68+
6069

6170
class IgnoreDimSort(tuple):
6271
"""A tuple subclass used to wrap the implicit_dims to indicate

tests/test_staggered_utils.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from sympy import simplify
44

55
from devito import (
6-
CELL, NODE, Dimension, Eq, Function, Grid, Operator, TimeFunction, VectorTimeFunction,
7-
div
6+
CELL, NODE, Eq, Function, Grid, Operator, TimeFunction, VectorTimeFunction, div
87
)
98
from devito.tools import as_tuple, powerset
109

@@ -173,16 +172,15 @@ def test_staggered_rebuild(stagg):
173172
f = Function(name='f', grid=grid, space_order=4, staggered=stagg)
174173
assert tuple(f.staggered.getters.keys()) == grid.dimensions
175174

176-
new_dims = (Dimension('x1'), Dimension('y1'), Dimension('z1'))
177-
f2 = f.func(dimensions=new_dims)
175+
f2 = f.func(name="f2")
178176

179-
assert f2.dimensions == new_dims
177+
assert f2.dimensions == f.dimensions
180178
assert tuple(f2.staggered) == tuple(f.staggered)
181-
assert tuple(f2.staggered.getters.keys()) == new_dims
179+
assert tuple(f2.staggered.getters.keys()) == f.dimensions
182180

183181
# Check that rebuild correctly set the staggered indices
184182
# with the new dimensions
185-
for (d, nd) in zip(grid.dimensions, new_dims, strict=True):
183+
for (d, nd) in zip(grid.dimensions, f.dimensions, strict=True):
186184
if d in as_tuple(stagg) or stagg is CELL:
187185
assert f2.indices[nd] == nd + nd.spacing / 2
188186
else:
@@ -200,3 +198,18 @@ def test_eval_at_different_dim():
200198
eq = Eq(tau.forward, v).evaluate
201199

202200
assert grid.time_dim not in eq.rhs.free_symbols
201+
202+
203+
def test_new_from_staggering():
204+
grid = Grid(shape=(31, 17, 25))
205+
x, _, _ = grid.dimensions
206+
207+
f = TimeFunction(name="f", grid=grid, staggered=x)
208+
# This used to fail since f.staggered as 4 elements (0, 1, 0, 0)
209+
# but it is processed for Dimension only.
210+
# Now properly converts Staggering to the ref (x,) at init
211+
g = TimeFunction(name="g", grid=grid, staggered=f.staggered)
212+
213+
assert g.staggered._ref == (x,)
214+
assert g.staggered == (0, 1, 0, 0)
215+
assert g.staggered == f.staggered

0 commit comments

Comments
 (0)