@@ -95,13 +95,36 @@ def dtype(self):
9595
9696 @cached_property
9797 def indices (self ):
98- return tuple (filter_ordered (flatten (getattr (i , 'indices' , ())
99- for i in self ._args_diff )))
98+ if not self ._args_diff :
99+ return DimensionTuple ()
100+
101+ # Get indices of all args and merge them
102+ mapper = {}
103+ for a in self ._args_diff :
104+ for d , i in a .indices .getters .items ():
105+ mapper .setdefault (d , []).append (i )
106+
107+ # Filter unique indices
108+ mapper = {k : v [0 ] if len (v ) == 1 else tuple (filter_ordered (v ))
109+ for k , v in mapper .items ()}
110+
111+ return DimensionTuple (* mapper .values (), getters = tuple (mapper .keys ()))
100112
101113 @cached_property
102114 def dimensions (self ):
103- return tuple (filter_ordered (flatten (getattr (i , 'dimensions' , ())
104- for i in self ._args_diff )))
115+ if not self ._args_diff :
116+ return DimensionTuple ()
117+
118+ # Use the staggering of the highest priority function
119+ return highest_priority (self ).dimensions
120+
121+ @cached_property
122+ def staggered (self ):
123+ if not self ._args_diff :
124+ return None
125+
126+ # Use the staggering of the highest priority function
127+ return highest_priority (self ).staggered
105128
106129 @cached_property
107130 def root_dimensions (self ):
@@ -117,11 +140,6 @@ def indices_ref(self):
117140 return DimensionTuple (* self .dimensions , getters = self .dimensions )
118141 return highest_priority (self ).indices_ref
119142
120- @cached_property
121- def staggered (self ):
122- return tuple (filter_ordered (flatten (getattr (i , 'staggered' , ())
123- for i in self ._args_diff )))
124-
125143 @cached_property
126144 def is_Staggered (self ):
127145 return any ([getattr (i , 'is_Staggered' , False ) for i in self ._args_diff ])
@@ -474,13 +492,21 @@ def has_free(self, *patterns):
474492 return all (i in self .free_symbols for i in patterns )
475493
476494
477- def highest_priority (DiffOp ):
495+ def highest_priority (diff_op ):
496+ if not diff_op ._args_diff :
497+ return diff_op
498+
478499 # We want to get the object with highest priority
479500 # We also need to make sure that the object with the largest
480501 # set of dimensions is used when multiple ones with the same
481502 # priority appear
482503 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
483- return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
504+ prio_func = sorted (diff_op ._args_diff , key = prio , reverse = True )[0 ]
505+
506+ # The highest priority must be a Function
507+ if not isinstance (prio_func , AbstractFunction ):
508+ return highest_priority (prio_func )
509+ return prio_func
484510
485511
486512class DifferentiableOp (Differentiable ):
@@ -548,8 +574,11 @@ class DifferentiableFunction(DifferentiableOp):
548574 def __new__ (cls , * args , ** kwargs ):
549575 return cls .__sympy_class__ .__new__ (cls , * args , ** kwargs )
550576
551- def _eval_at (self , func ):
552- return self
577+ @property
578+ def _fd_priority (self ):
579+ if highest_priority (self ) is self :
580+ return super ()._fd_priority
581+ return highest_priority (self )._fd_priority
553582
554583
555584class Add (DifferentiableOp , sympy .Add ):
@@ -633,26 +662,12 @@ def _gather_for_diff(self):
633662 if len (set (f .staggered for f in self ._args_diff )) == 1 :
634663 return self
635664
636- func_args = highest_priority (self )
637- new_args = []
638- ref_inds = func_args .indices_ref .getters
639-
640- for f in self .args :
641- if f not in self ._args_diff \
642- or f is func_args \
643- or isinstance (f , DifferentiableFunction ):
644- new_args .append (f )
645- else :
646- ind_f = f .indices_ref .getters
647- mapper = {ind_f .get (d , d ): ref_inds .get (d , d )
648- for d in self .dimensions
649- if ind_f .get (d , d ) is not ref_inds .get (d , d )}
650- if mapper :
651- new_args .append (f .subs (mapper ))
652- else :
653- new_args .append (f )
654-
655- return self .func (* new_args , evaluate = False )
665+ derivs , other = split (self .args , lambda a : isinstance (a , sympy .Derivative ))
666+ if len (derivs ) == 0 :
667+ return self ._eval_at (highest_priority (self ))
668+ else :
669+ other = self .func (* other )._eval_at (highest_priority (self ))
670+ return self .func (other , * derivs )
656671
657672
658673class Pow (DifferentiableOp , sympy .Pow ):
@@ -1034,6 +1049,9 @@ def __new__(cls, *args, base=None, **kwargs):
10341049 obj = super ().__new__ (cls , * args , ** kwargs )
10351050
10361051 try :
1052+ if base is obj :
1053+ # In some rare cases (rebuild?) base may be obj itself
1054+ base = base .base
10371055 obj .base = base
10381056 except AttributeError :
10391057 # This might happen if e.g. one attempts a (re)construction with
@@ -1061,6 +1079,10 @@ def _eval_at(self, func):
10611079 # and should not be re-evaluated at a different location
10621080 return self
10631081
1082+ @property
1083+ def indices_ref (self ):
1084+ return self .base .indices_ref
1085+
10641086
10651087class diffify :
10661088
@@ -1184,6 +1206,29 @@ def _(expr, x0, **kwargs):
11841206 return expr .func (interp_for_fd (expr .expr , x0_expr , ** kwargs ))
11851207
11861208
1209+ @interp_for_fd .register (Mul )
1210+ def _ (expr , x0 , ** kwargs ):
1211+ # For a Mul expression, we interpolate the whole expression
1212+ # Do we actually need interpolation
1213+ if all (expr .indices [d ] is i for d , i in x0 .items ()):
1214+ return expr
1215+
1216+ # Split args between those that need interp and those that don't
1217+ def test0 (a ):
1218+ return all (a .indices [d ] is i for d , i in x0 .items () if d in a .dimensions )
1219+
1220+ oa , ia = split (expr ._args_diff ,
1221+ lambda a : isinstance (a , sympy .Derivative ) or test0 (a ))
1222+ oa = oa + tuple (a for a in expr .args if a not in expr ._args_diff )
1223+
1224+ # Interpolate the necessary args
1225+ d_dims = tuple ((d , 0 ) for d in x0 )
1226+ fd_order = tuple (expr .interp_order for d in x0 )
1227+ iexpr = expr .func (* ia ).diff (* d_dims , fd_order = fd_order , x0 = x0 , ** kwargs )
1228+
1229+ return expr .func (iexpr , * oa )
1230+
1231+
11871232@interp_for_fd .register (sympy .Expr )
11881233def _ (expr , x0 , ** kwargs ):
11891234 if expr .args :
@@ -1194,7 +1239,8 @@ def _(expr, x0, **kwargs):
11941239
11951240@interp_for_fd .register (AbstractFunction )
11961241def _ (expr , x0 , ** kwargs ):
1197- x0_expr = {d : v for d , v in x0 .items () if v .has (d )}
1242+ x0_expr = {d : v for d , v in x0 .items () if v .has (d )
1243+ and expr .indices [d ] is not v }
11981244 if x0_expr :
11991245 return expr .subs ({expr .indices [d ]: v for d , v in x0_expr .items ()})
12001246 else :
0 commit comments