66from devito .finite_differences import Derivative
77from devito .finite_differences .differentiable import diff2sympy
88from devito .ir .equations import LoweredEq
9- from devito .passes .equations .linearity import collect_derivatives
9+ from devito .passes .equations .linearity import (
10+ _collect_derivatives as collect_derivatives
11+ )
1012from devito .tools import timed_region
1113
1214
@@ -34,8 +36,9 @@ def test_nocollection_if_diff_dims(self):
3436
3537 # Since all Function are time dependent, there should be no collection
3638 # and produce the same result as with the pre evaluated expression
37- expr = Operator ._lower_exprs ([eq ], options = {})[0 ]
38- expr2 = Operator ._lower_exprs ([eq .evaluate ], options = {})[0 ]
39+ options = {'deriv-collect' : True }
40+ expr = Operator ._lower_exprs ([eq ], options = options )[0 ]
41+ expr2 = Operator ._lower_exprs ([eq .evaluate ], options = options )[0 ]
3942
4043 assert expr == expr2
4144
@@ -45,7 +48,7 @@ def test_numeric_constant(self):
4548 u = TimeFunction (name = "u" , grid = grid , space_order = 4 , time_order = 2 )
4649
4750 eq = Eq (u .forward , u .dx .dx + 0.3 * u .dy .dx )
48- leq = collect_derivatives . func ([eq ])[0 ]
51+ leq = collect_derivatives ([eq ])[0 ]
4952
5053 assert len (leq .find (Derivative )) == 3
5154
@@ -56,7 +59,7 @@ def test_symbolic_constant(self):
5659 u = TimeFunction (name = "u" , grid = grid , space_order = 4 , time_order = 2 )
5760
5861 eq = Eq (u .forward , u .dx .dx + dt ** 0.2 * u .dy .dx )
59- leq = collect_derivatives . func ([eq ])[0 ]
62+ leq = collect_derivatives ([eq ])[0 ]
6063
6164 assert len (leq .find (Derivative )) == 3
6265
@@ -69,7 +72,7 @@ def test_symbolic_constant_times_add(self):
6972
7073 eq = Eq (u .forward , u .laplace + dt ** 0.2 * u .biharmonic (1 / f ))
7174
72- leq = collect_derivatives . func ([eq ])[0 ]
75+ leq = collect_derivatives ([eq ])[0 ]
7376
7477 assert len (eq .rhs .args ) == 3
7578 assert len (leq .rhs .args ) == 2
@@ -86,7 +89,7 @@ def test_solve(self):
8689
8790 pde = u .dt2 - (u .dx .dx + u .dy .dy ) - u .dx .dy
8891 eq = Eq (u .forward , solve (pde , u .forward ))
89- leq = collect_derivatives . func ([eq ])[0 ]
92+ leq = collect_derivatives ([eq ])[0 ]
9093
9194 assert len (eq .rhs .find (Derivative )) == 5
9295 assert len (leq .rhs .find (Derivative )) == 4
@@ -99,7 +102,7 @@ def test_nocollection_if_unworthy(self):
99102 u = TimeFunction (name = "u" , grid = grid )
100103
101104 eq = Eq (u .forward , (0.4 + dt )* (u .dx + u .dy ))
102- leq = collect_derivatives . func ([eq ])[0 ]
105+ leq = collect_derivatives ([eq ])[0 ]
103106
104107 assert eq == leq
105108
@@ -112,7 +115,7 @@ def test_pull_and_collect(self):
112115 v = TimeFunction (name = "v" , grid = grid )
113116
114117 eq = Eq (u .forward , ((0.4 + dt )* u .dx + 0.3 )* hx + v .dx )
115- leq = collect_derivatives . func ([eq ])[0 ]
118+ leq = collect_derivatives ([eq ])[0 ]
116119
117120 assert eq != leq
118121 args = leq .rhs .args
@@ -129,7 +132,7 @@ def test_pull_and_collect_nested(self):
129132 v = TimeFunction (name = "v" , grid = grid , space_order = 2 )
130133
131134 eq = Eq (u .forward , (((0.4 + dt )* u .dx + 0.3 )* hx + v .dx ).dy + (0.2 + hy )* v .dy )
132- leq = collect_derivatives . func ([eq ])[0 ]
135+ leq = collect_derivatives ([eq ])[0 ]
133136
134137 assert eq != leq
135138 assert leq .rhs == ((v + hx * (0.4 + dt )* u ).dx + 0.3 * hx + (0.2 + hy )* v ).dy
@@ -143,7 +146,7 @@ def test_pull_and_collect_nested_v2(self):
143146 v = TimeFunction (name = "v" , grid = grid , space_order = 2 )
144147
145148 eq = Eq (u .forward , ((0.4 + dt * (hy + 1. + hx * hy ))* u .dx + 0.3 )* hx + v .dx )
146- leq = collect_derivatives . func ([eq ])[0 ]
149+ leq = collect_derivatives ([eq ])[0 ]
147150
148151 assert eq != leq
149152 assert leq .rhs == 0.3 * hx + (hx * (0.4 + dt * (hy + 1. + hx * hy ))* u + v ).dx
@@ -158,7 +161,7 @@ def test_pull_and_collect_nested_v3(self):
158161 v = TimeFunction (name = "v" , grid = grid , space_order = 2 )
159162
160163 eq = Eq (u .forward , 0.4 + a * (hx + dt * (u .dx + v .dx )))
161- leq = collect_derivatives . func ([eq ])[0 ]
164+ leq = collect_derivatives ([eq ])[0 ]
162165
163166 assert eq != leq
164167 assert leq .rhs == 0.4 + a * (hx + (dt * u + dt * v ).dx )
@@ -172,7 +175,7 @@ def test_nocollection_subdims(self):
172175 f = Function (name = 'f' , grid = grid )
173176
174177 eq = Eq (u .forward , u .dx + 0.2 * f [xi , yi ]* v .dx )
175- leq = collect_derivatives . func ([eq ])[0 ]
178+ leq = collect_derivatives ([eq ])[0 ]
176179
177180 assert eq == leq
178181
@@ -184,7 +187,7 @@ def test_nocollection_staggered(self):
184187 v = TimeFunction (name = "v" , grid = grid , staggered = x )
185188
186189 eq = Eq (u .forward , u .dx + v .dx )
187- leq = collect_derivatives . func ([eq ])[0 ]
190+ leq = collect_derivatives ([eq ])[0 ]
188191
189192 assert eq == leq
190193
@@ -195,13 +198,13 @@ def test_nocollection_mixed_order(self):
195198
196199 # First case is obvious...
197200 eq = Eq (u .forward , u .dx2 + u .dx .dy + 1. )
198- leq = collect_derivatives . func ([eq ])[0 ]
201+ leq = collect_derivatives ([eq ])[0 ]
199202
200203 assert eq == leq
201204
202205 # y-derivative should not get collected!
203206 eq = Eq (u .forward , u .dy2 + u .dx .dy + 1. )
204- leq = collect_derivatives . func ([eq ])[0 ]
207+ leq = collect_derivatives ([eq ])[0 ]
205208
206209 assert eq == leq
207210
0 commit comments