44
55from devito import (Grid , Function , TimeFunction , Eq , Operator , NODE , cos , sin ,
66 ConditionalDimension , left , right , centered , div , grad ,
7- curl , laplace , VectorFunction )
7+ curl , laplace , VectorFunction , TensorFunction )
88from devito .finite_differences import Derivative , Differentiable , diffify
99from devito .finite_differences .differentiable import (Add , EvalDerivative , IndexSum ,
1010 IndexDerivative , Weights ,
@@ -636,11 +636,22 @@ def test_grad_w_side(self, side):
636636 expr1 = VectorFunction (name = f"{ f .name } _vec" , space_order = f .space_order ,
637637 components = comps , grid = grid ).evaluate
638638
639- expr2 = f .grad (side = side ).evaluate
640- expr3 = grad (f , side = side ).evaluate
639+ assert expr1 = = f .grad (side = side ).evaluate
640+ assert expr1 = = grad (f , side = side ).evaluate
641641
642- assert expr1 == expr2
643- assert expr1 == expr3
642+ @pytest .mark .parametrize ('side' , [left , right , centered ])
643+ def test_vector_grad_w_side (self , side ):
644+ grid = Grid (shape = (11 , 11 ))
645+ f = VectorFunction (name = 'f' , grid = grid , space_order = 2 , staggered = (None , None ))
646+
647+ comps = ((f [0 ].dx (side = side ), f [0 ].dy (side = side )),
648+ (f [1 ].dx (side = side ), f [1 ].dy (side = side )))
649+
650+ expr1 = TensorFunction (name = f"{ f .name } _tens" , space_order = f .space_order ,
651+ components = comps , grid = grid ).evaluate
652+
653+ assert expr1 == f .grad (side = side ).evaluate
654+ assert expr1 == grad (f , side = side ).evaluate
644655
645656 @pytest .mark .parametrize ('side' , [left , right , centered ])
646657 def test_div_w_side (self , side ):
@@ -649,11 +660,8 @@ def test_div_w_side(self, side):
649660
650661 expr1 = (f [0 ].dx (side = side ) + f [1 ].dy (side = side )).evaluate
651662
652- expr2 = f .div (side = side ).evaluate
653- expr3 = div (f , side = side ).evaluate
654-
655- assert expr1 == expr2
656- assert expr1 == expr3
663+ assert expr1 == f .div (side = side ).evaluate
664+ assert expr1 == div (f , side = side ).evaluate
657665
658666 @pytest .mark .parametrize ('side' , [left , right , centered ])
659667 def test_curl_w_side (self , side ):
@@ -668,11 +676,8 @@ def test_curl_w_side(self, side):
668676 expr1 = VectorFunction (name = f"{ f .name } _vec" , space_order = f .space_order ,
669677 components = comps , grid = grid ).evaluate
670678
671- expr2 = f .curl (side = side ).evaluate
672- expr3 = curl (f , side = side ).evaluate
673-
674- assert expr1 == expr2
675- assert expr1 == expr3
679+ assert expr1 == f .curl (side = side ).evaluate
680+ assert expr1 == curl (f , side = side ).evaluate
676681
677682 @pytest .mark .parametrize ('side' , [left , right , centered ])
678683 def test_laplace_w_side (self , side ):
@@ -681,11 +686,8 @@ def test_laplace_w_side(self, side):
681686
682687 expr1 = (f .dx2 (side = side ) + f .dy2 (side = side )).evaluate
683688
684- expr2 = f .laplacian (side = side ).evaluate
685- expr3 = laplace (f , side = side ).evaluate
686-
687- assert expr1 == expr2
688- assert expr1 == expr3
689+ assert expr1 == f .laplacian (side = side ).evaluate
690+ assert expr1 == laplace (f , side = side ).evaluate
689691
690692 def test_substitution (self ):
691693 grid = Grid ((11 , 11 ))
0 commit comments