Skip to content

Commit 039ddd0

Browse files
committed
tests: Add a vector grad test and simplify existing tests
1 parent 9f93edd commit 039ddd0

1 file changed

Lines changed: 22 additions & 20 deletions

File tree

tests/test_derivatives.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from devito import (Grid, Function, TimeFunction, Eq, Operator, NODE, cos, sin,
66
ConditionalDimension, left, right, centered, div, grad,
7-
curl, laplace, VectorFunction)
7+
curl, laplace, VectorFunction, TensorFunction)
88
from devito.finite_differences import Derivative, Differentiable, diffify
99
from devito.finite_differences.differentiable import (Add, EvalDerivative, IndexSum,
1010
IndexDerivative, Weights,
@@ -636,11 +636,22 @@ def test_grad_w_side(self, side):
636636
expr1 = VectorFunction(name=f"{f.name}_vec", space_order=f.space_order,
637637
components=comps, grid=grid).evaluate
638638

639-
expr2 = f.grad(side=side).evaluate
640-
expr3 = grad(f, side=side).evaluate
639+
assert expr1 == f.grad(side=side).evaluate
640+
assert expr1 == grad(f, side=side).evaluate
641641

642-
assert expr1 == expr2
643-
assert expr1 == expr3
642+
@pytest.mark.parametrize('side', [left, right, centered])
643+
def test_vector_grad_w_side(self, side):
644+
grid = Grid(shape=(11, 11))
645+
f = VectorFunction(name='f', grid=grid, space_order=2, staggered=(None, None))
646+
647+
comps = ((f[0].dx(side=side), f[0].dy(side=side)),
648+
(f[1].dx(side=side), f[1].dy(side=side)))
649+
650+
expr1 = TensorFunction(name=f"{f.name}_tens", space_order=f.space_order,
651+
components=comps, grid=grid).evaluate
652+
653+
assert expr1 == f.grad(side=side).evaluate
654+
assert expr1 == grad(f, side=side).evaluate
644655

645656
@pytest.mark.parametrize('side', [left, right, centered])
646657
def test_div_w_side(self, side):
@@ -649,11 +660,8 @@ def test_div_w_side(self, side):
649660

650661
expr1 = (f[0].dx(side=side) + f[1].dy(side=side)).evaluate
651662

652-
expr2 = f.div(side=side).evaluate
653-
expr3 = div(f, side=side).evaluate
654-
655-
assert expr1 == expr2
656-
assert expr1 == expr3
663+
assert expr1 == f.div(side=side).evaluate
664+
assert expr1 == div(f, side=side).evaluate
657665

658666
@pytest.mark.parametrize('side', [left, right, centered])
659667
def test_curl_w_side(self, side):
@@ -668,11 +676,8 @@ def test_curl_w_side(self, side):
668676
expr1 = VectorFunction(name=f"{f.name}_vec", space_order=f.space_order,
669677
components=comps, grid=grid).evaluate
670678

671-
expr2 = f.curl(side=side).evaluate
672-
expr3 = curl(f, side=side).evaluate
673-
674-
assert expr1 == expr2
675-
assert expr1 == expr3
679+
assert expr1 == f.curl(side=side).evaluate
680+
assert expr1 == curl(f, side=side).evaluate
676681

677682
@pytest.mark.parametrize('side', [left, right, centered])
678683
def test_laplace_w_side(self, side):
@@ -681,11 +686,8 @@ def test_laplace_w_side(self, side):
681686

682687
expr1 = (f.dx2(side=side) + f.dy2(side=side)).evaluate
683688

684-
expr2 = f.laplacian(side=side).evaluate
685-
expr3 = laplace(f, side=side).evaluate
686-
687-
assert expr1 == expr2
688-
assert expr1 == expr3
689+
assert expr1 == f.laplacian(side=side).evaluate
690+
assert expr1 == laplace(f, side=side).evaluate
689691

690692
def test_substitution(self):
691693
grid = Grid((11, 11))

0 commit comments

Comments
 (0)