Skip to content

Commit 1c77ae5

Browse files
committed
compiler: Add optoption for collecting derivatives
1 parent 22a288c commit 1c77ae5

6 files changed

Lines changed: 47 additions & 25 deletions

File tree

devito/core/cpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def _normalize_kwargs(cls, **kwargs):
7676

7777
# Code generation options for derivatives
7878
o['expand'] = oo.pop('expand', cls.EXPAND)
79+
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
7980
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
8081
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
8182

@@ -151,7 +152,7 @@ class Cpu64AdvOperator(Cpu64OperatorMixin, CoreOperator):
151152
@classmethod
152153
@timed_pass(name='specializing.DSL')
153154
def _specialize_dsl(cls, expressions, **kwargs):
154-
expressions = collect_derivatives(expressions)
155+
expressions = collect_derivatives(expressions, **kwargs)
155156

156157
return expressions
157158

@@ -254,7 +255,7 @@ class Cpu64CustomOperator(Cpu64OperatorMixin, CustomOperator):
254255
@classmethod
255256
def _make_dsl_passes_mapper(cls, **kwargs):
256257
return {
257-
'collect-derivs': collect_derivatives,
258+
'deriv-collect': collect_derivatives,
258259
}
259260

260261
@classmethod
@@ -309,7 +310,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
309310

310311
_known_passes = (
311312
# DSL
312-
'collect-derivs',
313+
'deriv-collect',
313314
# Expressions
314315
'buffering',
315316
# Clusters

devito/core/gpu.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _normalize_kwargs(cls, **kwargs):
8989

9090
# Code generation options for derivatives
9191
o['expand'] = oo.pop('expand', cls.EXPAND)
92+
o['deriv-collect'] = oo.pop('deriv-collect', cls.DERIV_COLLECT)
9293
o['deriv-schedule'] = oo.pop('deriv-schedule', cls.DERIV_SCHEDULE)
9394
o['deriv-unroll'] = oo.pop('deriv-unroll', False)
9495

@@ -189,7 +190,7 @@ class DeviceAdvOperator(DeviceOperatorMixin, CoreOperator):
189190
@classmethod
190191
@timed_pass(name='specializing.DSL')
191192
def _specialize_dsl(cls, expressions, **kwargs):
192-
expressions = collect_derivatives(expressions)
193+
expressions = collect_derivatives(expressions, **kwargs)
193194

194195
return expressions
195196

@@ -281,7 +282,7 @@ class DeviceCustomOperator(DeviceOperatorMixin, CustomOperator):
281282
@classmethod
282283
def _make_dsl_passes_mapper(cls, **kwargs):
283284
return {
284-
'collect-derivs': collect_derivatives,
285+
'deriv-collect': collect_derivatives,
285286
}
286287

287288
@classmethod
@@ -331,7 +332,7 @@ def _make_iet_passes_mapper(cls, **kwargs):
331332

332333
_known_passes = (
333334
# DSL
334-
'collect-derivs',
335+
'deriv-collect',
335336
# Expressions
336337
'buffering',
337338
# Clusters

devito/core/operator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ class BasicOperator(Operator):
123123
finite-difference derivatives.
124124
"""
125125

126+
DERIV_COLLECT = True
127+
"""
128+
Factorize finite-difference derivatives exploiting the linearity of the FD
129+
operators.
130+
"""
131+
126132
DERIV_SCHEDULE = 'basic'
127133
"""
128134
The schedule to use for the computation of finite-difference derivatives.
@@ -296,7 +302,7 @@ def _specialize_dsl(cls, expressions, **kwargs):
296302
# Call passes
297303
for i in passes:
298304
try:
299-
expressions = passes_mapper[i](expressions)
305+
expressions = passes_mapper[i](expressions, **kwargs)
300306
except KeyError:
301307
pass
302308

devito/passes/equations/linearity.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,23 @@
1111

1212

1313
@timed_pass()
14-
def collect_derivatives(expressions):
14+
def collect_derivatives(expressions, options=None, **kwargs):
1515
"""
1616
Exploit linearity of finite-differences to collect `Derivative`'s of
1717
same type. This may help CIRE creating fewer temporaries while catching
1818
larger redundant sub-expressions.
1919
"""
20+
deriv_collect = options['deriv-collect']
21+
if not deriv_collect:
22+
return expressions
23+
24+
return _collect_derivatives(expressions)
25+
26+
27+
def _collect_derivatives(expressions):
28+
"""
29+
Carry out the bulk of the work for `collect_derivatives`.
30+
"""
2031
processed = []
2132
for e in expressions:
2233
# Track type and number of nested Derivatives

tests/test_dse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2099,7 +2099,7 @@ def test_sum_of_nested_derivatives(self, expr, exp_arrays, exp_ops):
20992099
eqn = Eq(v.forward, eval(expr))
21002100

21012101
op0 = Operator(eqn, opt=('noop', {'openmp': True}))
2102-
op1 = Operator(eqn, opt=('collect-derivs', 'cire-sops', {'openmp': True}))
2102+
op1 = Operator(eqn, opt=('deriv-collect', 'cire-sops', {'openmp': True}))
21032103
op2 = Operator(eqn, opt=('cire-sops', {'openmp': True}))
21042104
op3 = Operator(eqn, opt=('advanced', {'openmp': True}))
21052105

tests/test_lower_exprs.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from devito.finite_differences import Derivative
77
from devito.finite_differences.differentiable import diff2sympy
88
from devito.ir.equations import LoweredEq
9-
from devito.passes.equations.linearity import collect_derivatives
9+
from devito.passes.equations.linearity import (
10+
_collect_derivatives as collect_derivatives
11+
)
1012
from devito.tools import timed_region
1113

1214

@@ -34,8 +36,9 @@ def test_nocollection_if_diff_dims(self):
3436

3537
# Since all Function are time dependent, there should be no collection
3638
# and produce the same result as with the pre evaluated expression
37-
expr = Operator._lower_exprs([eq], options={})[0]
38-
expr2 = Operator._lower_exprs([eq.evaluate], options={})[0]
39+
options = {'deriv-collect': True}
40+
expr = Operator._lower_exprs([eq], options=options)[0]
41+
expr2 = Operator._lower_exprs([eq.evaluate], options=options)[0]
3942

4043
assert expr == expr2
4144

@@ -45,7 +48,7 @@ def test_numeric_constant(self):
4548
u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
4649

4750
eq = Eq(u.forward, u.dx.dx + 0.3*u.dy.dx)
48-
leq = collect_derivatives.func([eq])[0]
51+
leq = collect_derivatives([eq])[0]
4952

5053
assert len(leq.find(Derivative)) == 3
5154

@@ -56,7 +59,7 @@ def test_symbolic_constant(self):
5659
u = TimeFunction(name="u", grid=grid, space_order=4, time_order=2)
5760

5861
eq = Eq(u.forward, u.dx.dx + dt**0.2*u.dy.dx)
59-
leq = collect_derivatives.func([eq])[0]
62+
leq = collect_derivatives([eq])[0]
6063

6164
assert len(leq.find(Derivative)) == 3
6265

@@ -69,7 +72,7 @@ def test_symbolic_constant_times_add(self):
6972

7073
eq = Eq(u.forward, u.laplace + dt**0.2*u.biharmonic(1/f))
7174

72-
leq = collect_derivatives.func([eq])[0]
75+
leq = collect_derivatives([eq])[0]
7376

7477
assert len(eq.rhs.args) == 3
7578
assert len(leq.rhs.args) == 2
@@ -86,7 +89,7 @@ def test_solve(self):
8689

8790
pde = u.dt2 - (u.dx.dx + u.dy.dy) - u.dx.dy
8891
eq = Eq(u.forward, solve(pde, u.forward))
89-
leq = collect_derivatives.func([eq])[0]
92+
leq = collect_derivatives([eq])[0]
9093

9194
assert len(eq.rhs.find(Derivative)) == 5
9295
assert len(leq.rhs.find(Derivative)) == 4
@@ -99,7 +102,7 @@ def test_nocollection_if_unworthy(self):
99102
u = TimeFunction(name="u", grid=grid)
100103

101104
eq = Eq(u.forward, (0.4 + dt)*(u.dx + u.dy))
102-
leq = collect_derivatives.func([eq])[0]
105+
leq = collect_derivatives([eq])[0]
103106

104107
assert eq == leq
105108

@@ -112,7 +115,7 @@ def test_pull_and_collect(self):
112115
v = TimeFunction(name="v", grid=grid)
113116

114117
eq = Eq(u.forward, ((0.4 + dt)*u.dx + 0.3)*hx + v.dx)
115-
leq = collect_derivatives.func([eq])[0]
118+
leq = collect_derivatives([eq])[0]
116119

117120
assert eq != leq
118121
args = leq.rhs.args
@@ -129,7 +132,7 @@ def test_pull_and_collect_nested(self):
129132
v = TimeFunction(name="v", grid=grid, space_order=2)
130133

131134
eq = Eq(u.forward, (((0.4 + dt)*u.dx + 0.3)*hx + v.dx).dy + (0.2 + hy)*v.dy)
132-
leq = collect_derivatives.func([eq])[0]
135+
leq = collect_derivatives([eq])[0]
133136

134137
assert eq != leq
135138
assert leq.rhs == ((v + hx*(0.4 + dt)*u).dx + 0.3*hx + (0.2 + hy)*v).dy
@@ -143,7 +146,7 @@ def test_pull_and_collect_nested_v2(self):
143146
v = TimeFunction(name="v", grid=grid, space_order=2)
144147

145148
eq = Eq(u.forward, ((0.4 + dt*(hy + 1. + hx*hy))*u.dx + 0.3)*hx + v.dx)
146-
leq = collect_derivatives.func([eq])[0]
149+
leq = collect_derivatives([eq])[0]
147150

148151
assert eq != leq
149152
assert leq.rhs == 0.3*hx + (hx*(0.4 + dt*(hy + 1. + hx*hy))*u + v).dx
@@ -158,7 +161,7 @@ def test_pull_and_collect_nested_v3(self):
158161
v = TimeFunction(name="v", grid=grid, space_order=2)
159162

160163
eq = Eq(u.forward, 0.4 + a*(hx + dt*(u.dx + v.dx)))
161-
leq = collect_derivatives.func([eq])[0]
164+
leq = collect_derivatives([eq])[0]
162165

163166
assert eq != leq
164167
assert leq.rhs == 0.4 + a*(hx + (dt*u + dt*v).dx)
@@ -172,7 +175,7 @@ def test_nocollection_subdims(self):
172175
f = Function(name='f', grid=grid)
173176

174177
eq = Eq(u.forward, u.dx + 0.2*f[xi, yi]*v.dx)
175-
leq = collect_derivatives.func([eq])[0]
178+
leq = collect_derivatives([eq])[0]
176179

177180
assert eq == leq
178181

@@ -184,7 +187,7 @@ def test_nocollection_staggered(self):
184187
v = TimeFunction(name="v", grid=grid, staggered=x)
185188

186189
eq = Eq(u.forward, u.dx + v.dx)
187-
leq = collect_derivatives.func([eq])[0]
190+
leq = collect_derivatives([eq])[0]
188191

189192
assert eq == leq
190193

@@ -195,13 +198,13 @@ def test_nocollection_mixed_order(self):
195198

196199
# First case is obvious...
197200
eq = Eq(u.forward, u.dx2 + u.dx.dy + 1.)
198-
leq = collect_derivatives.func([eq])[0]
201+
leq = collect_derivatives([eq])[0]
199202

200203
assert eq == leq
201204

202205
# y-derivative should not get collected!
203206
eq = Eq(u.forward, u.dy2 + u.dx.dy + 1.)
204-
leq = collect_derivatives.func([eq])[0]
207+
leq = collect_derivatives([eq])[0]
205208

206209
assert eq == leq
207210

0 commit comments

Comments
 (0)