Skip to content

Commit 28e0b91

Browse files
committed
api: fix interpolate with complex dtype
1 parent 6a3c7b3 commit 28e0b91

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

devito/passes/clusters/cse.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def retrieve_ctemps(exprs, mode='all'):
3636
return search(exprs, lambda expr: isinstance(expr, CTemp), mode, 'dfs')
3737

3838

39+
def cse_dtype(exprdtype, cdtype):
40+
"""
41+
Return the dtype of a CSE temporary given the dtype of the expression to be
42+
captured and the cluster's dtype.
43+
"""
44+
if np.issubdtype(cdtype, np.complexfloating):
45+
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
46+
else:
47+
# Real cluster, can safely promote to the largest precision
48+
return np.promote_types(exprdtype, cdtype).type
49+
50+
3951
@cluster_pass
4052
def cse(cluster, sregistry=None, options=None, **kwargs):
4153
"""
@@ -86,7 +98,7 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
8698
if cluster.is_fence:
8799
return cluster
88100

89-
make_dtype = lambda e: np.promote_types(e.dtype, dtype).type
101+
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
90102
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
91103

92104
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)

tests/test_interpolation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,21 @@ def test_point_symbol_types(dtype, expected):
855855
assert point_symbol.dtype is expected
856856

857857

858+
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
859+
def test_interp_complex(dtype):
860+
grid = Grid((11, 11, 11))
861+
862+
sc = SparseFunction(name="sc", grid=grid, npoint=1, dtype=dtype)
863+
sc.coordinates.data[:] = [.5, .5, .5]
864+
865+
fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype)
866+
fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape)
867+
opC = Operator([sc.interpolate(expr=fc)], name="OpC")
868+
opC()
869+
870+
assert np.isclose(sc.data[0], fc.data[5, 5, 5])
871+
872+
858873
class SD0(SubDomain):
859874
name = 'sd0'
860875

0 commit comments

Comments
 (0)