33from sympy import simplify
44
55from devito import (
6- CELL , NODE , Dimension , Eq , Function , Grid , Operator , TimeFunction , VectorTimeFunction ,
7- div
6+ CELL , NODE , Eq , Function , Grid , Operator , TimeFunction , VectorTimeFunction , div
87)
98from devito .tools import as_tuple , powerset
109
@@ -173,16 +172,15 @@ def test_staggered_rebuild(stagg):
173172 f = Function (name = 'f' , grid = grid , space_order = 4 , staggered = stagg )
174173 assert tuple (f .staggered .getters .keys ()) == grid .dimensions
175174
176- new_dims = (Dimension ('x1' ), Dimension ('y1' ), Dimension ('z1' ))
177- f2 = f .func (dimensions = new_dims )
175+ f2 = f .func (name = "f2" )
178176
179- assert f2 .dimensions == new_dims
177+ assert f2 .dimensions == f . dimensions
180178 assert tuple (f2 .staggered ) == tuple (f .staggered )
181- assert tuple (f2 .staggered .getters .keys ()) == new_dims
179+ assert tuple (f2 .staggered .getters .keys ()) == f . dimensions
182180
183181 # Check that rebuild correctly set the staggered indices
184182 # with the new dimensions
185- for (d , nd ) in zip (grid .dimensions , new_dims , strict = True ):
183+ for (d , nd ) in zip (grid .dimensions , f . dimensions , strict = True ):
186184 if d in as_tuple (stagg ) or stagg is CELL :
187185 assert f2 .indices [nd ] == nd + nd .spacing / 2
188186 else :
@@ -200,3 +198,18 @@ def test_eval_at_different_dim():
200198 eq = Eq (tau .forward , v ).evaluate
201199
202200 assert grid .time_dim not in eq .rhs .free_symbols
201+
202+
203+ def test_new_from_staggering ():
204+ grid = Grid (shape = (31 , 17 , 25 ))
205+ x , _ , _ = grid .dimensions
206+
207+ f = TimeFunction (name = "f" , grid = grid , staggered = x )
208+ # This used to fail since f.staggered as 4 elements (0, 1, 0, 0)
209+ # but it is processed for Dimension only.
210+ # Now properly converts Staggering to the ref (x,) at init
211+ g = TimeFunction (name = "g" , grid = grid , staggered = f .staggered )
212+
213+ assert g .staggered ._ref == (x ,)
214+ assert g .staggered == (0 , 1 , 0 , 0 )
215+ assert g .staggered == f .staggered
0 commit comments