@@ -1444,6 +1444,42 @@ def test_fission_for_parallelism(self, exprs, fissioned, shared):
14441444 # Fission happened
14451445 assert i [exp_depth ].dim is exp_dim
14461446
1447+ def test_fission_for_parallelism_b (self ):
1448+ so = 2
1449+ grid = Grid (shape = (10 , 10 , 10 ))
1450+ x , y , z = grid .dimensions
1451+
1452+ f0 = TimeFunction (name = 'f0' , grid = grid , space_order = so , staggered = (x ,))
1453+ f1 = TimeFunction (name = 'f1' , grid = grid , space_order = so , staggered = (y ,))
1454+
1455+ f2 = TimeFunction (name = 'f2' , grid = grid , space_order = so , staggered = (x , z ))
1456+ f3 = TimeFunction (name = 'f3' , grid = grid , space_order = so , staggered = (y , z ))
1457+
1458+ f4 = TimeFunction (name = 'f4' , grid = grid , space_order = so , staggered = NODE )
1459+
1460+ eq0 = Eq (f2 , f0 .dz )
1461+ eq1 = Eq (f3 , f1 .dz )
1462+ eq2 = Eq (f4 , f2 + f3 )
1463+
1464+ op = Operator ([eq0 , eq1 , eq2 ])
1465+
1466+ trees = retrieve_iteration_tree (op )
1467+
1468+ # First two equations should be fused for parallelism, but the third should be
1469+ # fissioned
1470+ assert len (trees ) == 2
1471+ assert len (trees [0 ][- 1 ].nodes [0 ].exprs ) == 2
1472+ assert len (trees [1 ][- 1 ].nodes [0 ].exprs ) == 1
1473+
1474+ def check_expr_contents (expr , expected ):
1475+ assert all (f .base in expr .expr_symbols for f in expected )
1476+
1477+ # Check expressions match equations
1478+ check_expr_contents (trees [0 ][- 1 ].nodes [0 ].exprs [0 ], (f2 , f0 ))
1479+ check_expr_contents (trees [0 ][- 1 ].nodes [0 ].exprs [1 ], (f3 , f1 ))
1480+
1481+ check_expr_contents (trees [1 ][- 1 ].nodes [0 ].exprs [0 ], (f4 , f2 , f3 ))
1482+
14471483 @pytest .mark .parametrize ('exprs' , [
14481484 # 0) Storage related dependence
14491485 ('Eq(u.forward, v)' , 'Eq(v, u.dxl)' ),
0 commit comments