@@ -41,7 +41,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
4141 return 0 ;
4242 }
4343
44- if constexpr (is_autodiff_v <Ta> && is_autodiff_v<Tb> && is_autodiff_v< Td>) {
44+ if constexpr (is_all_autodiff_v <Ta, Tb, Td>) {
4545 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
4646 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
4747 arena_t <promote_scalar_t <var, Td>> arena_D = D;
@@ -62,8 +62,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
6262 });
6363
6464 return res;
65- } else if constexpr (is_autodiff_v<
66- Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
65+ } else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
6766 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
6867 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
6968 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
@@ -80,8 +79,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
8079 });
8180
8281 return res;
83- } else if constexpr (is_autodiff_v<
84- Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
82+ } else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
8583 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
8684 const auto & B_ref = to_ref (B);
8785 arena_t <promote_scalar_t <var, Td>> arena_D = D;
@@ -100,8 +98,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
10098 });
10199
102100 return res;
103- } else if constexpr (is_autodiff_v<
104- Ta> && is_constant_v<Tb> && is_constant_v<Td>) {
101+ } else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
105102 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
106103 const auto & B_ref = to_ref (B);
107104 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
@@ -117,8 +114,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
117114 });
118115
119116 return res;
120- } else if constexpr (is_constant_v<
121- Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
117+ } else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
122118 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
123119 arena_t <promote_scalar_t <var, Td>> arena_D = D;
124120 auto AsolveB = to_arena (A.ldlt ().solve (arena_B.val ()));
@@ -136,8 +132,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
136132 });
137133
138134 return res;
139- } else if constexpr (is_constant_v<
140- Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
135+ } else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
141136 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
142137 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
143138 auto AsolveB = to_arena (A.ldlt ().solve (arena_B.val ()));
@@ -149,8 +144,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, LDLT_factor<Ta>& A,
149144 });
150145
151146 return res;
152- } else if constexpr (is_constant_v<
153- Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
147+ } else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
154148 const auto & B_ref = to_ref (B);
155149 arena_t <promote_scalar_t <var, Td>> arena_D = D;
156150 auto BTAsolveB = to_arena (value_of (B_ref).transpose ()
@@ -196,7 +190,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
196190 return 0 ;
197191 }
198192
199- if constexpr (is_autodiff_v <Ta> && is_autodiff_v<Tb> && is_autodiff_v< Td>) {
193+ if constexpr (is_all_autodiff_v <Ta, Tb, Td>) {
200194 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
201195 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
202196 arena_t <promote_scalar_t <var, Td>> arena_D = D;
@@ -216,8 +210,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
216210 });
217211
218212 return res;
219- } else if constexpr (is_autodiff_v<
220- Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
213+ } else if constexpr (is_all_autodiff_v<Ta, Tb> && is_constant_v<Td>) {
221214 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
222215 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
223216 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
@@ -235,8 +228,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
235228 });
236229
237230 return res;
238- } else if constexpr (is_autodiff_v<
239- Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
231+ } else if constexpr (is_all_autodiff_v<Ta, Td> && is_constant_v<Tb>) {
240232 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
241233 const auto & B_ref = to_ref (B);
242234 arena_t <promote_scalar_t <var, Td>> arena_D = D;
@@ -255,8 +247,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
255247 });
256248
257249 return res;
258- } else if constexpr (is_autodiff_v<
259- Ta> && is_constant_v<Tb> && is_constant_v<Td>) {
250+ } else if constexpr (is_autodiff_v<Ta> && is_constant_all_v<Tb, Td>) {
260251 arena_t <promote_scalar_t <var, Ta>> arena_A = A.matrix ();
261252 const auto & B_ref = to_ref (B);
262253 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
@@ -273,8 +264,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
273264 });
274265
275266 return res;
276- } else if constexpr (is_constant_v<
277- Ta> && is_autodiff_v<Tb> && is_autodiff_v<Td>) {
267+ } else if constexpr (is_constant_v<Ta> && is_all_autodiff_v<Tb, Td>) {
278268 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
279269 arena_t <promote_scalar_t <var, Td>> arena_D = D;
280270 auto AsolveB = to_arena (A.ldlt ().solve (arena_B.val ()));
@@ -291,8 +281,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
291281 });
292282
293283 return res;
294- } else if constexpr (is_constant_v<
295- Ta> && is_autodiff_v<Tb> && is_constant_v<Td>) {
284+ } else if constexpr (is_constant_all_v<Ta, Td> && is_autodiff_v<Tb>) {
296285 arena_t <promote_scalar_t <var, Tb>> arena_B = B;
297286 arena_t <promote_scalar_t <double , Td>> arena_D = value_of (D);
298287 auto AsolveB = to_arena (A.ldlt ().solve (arena_B.val ()));
@@ -305,8 +294,7 @@ inline var trace_gen_inv_quad_form_ldlt(const Td& D, const LDLT_factor<Ta>& A,
305294 });
306295
307296 return res;
308- } else if constexpr (is_constant_v<
309- Ta> && is_constant_v<Tb> && is_autodiff_v<Td>) {
297+ } else if constexpr (is_constant_all_v<Ta, Tb> && is_autodiff_v<Td>) {
310298 const auto & B_ref = to_ref (B);
311299 arena_t <promote_scalar_t <var, Td>> arena_D = D;
312300 auto BTAsolveB = to_arena (value_of (B_ref).transpose ()
0 commit comments