11# # Pullback
22
3- struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
3+ struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
44 _sig:: Val{SIG}
55 cache:: Tcache
6- dy_righttype:: DY
76 args_to_zero:: NTuple{N, Bool}
87end
98
@@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
1211 ) where {F, C}
1312 _sig = DI. signature (f, backend, x, ty, contexts... ; strict)
1413 config = get_config (backend)
15- cache = prepare_pullback_cache (
16- f, x, map (DI. unwrap, contexts)... ; config. debug_mode, config. silence_debug_messages
17- )
18- y = f (x, map (DI. unwrap, contexts)... )
19- dy_righttype = zero_tangent (y)
14+ cache = prepare_pullback_cache (f, x, map (DI. unwrap, contexts)... ; config)
2015 contexts_tup_false = map (_ -> false , contexts)
2116 args_to_zero = (
2217 false , # f
2318 true , # x
2419 contexts_tup_false... ,
2520 )
26- prep = MooncakeOneArgPullbackPrep (_sig, cache, dy_righttype, args_to_zero)
21+ prep = MooncakeOneArgPullbackPrep (_sig, cache, args_to_zero)
2722 return prep
2823end
2924
@@ -37,10 +32,8 @@ function DI.value_and_pullback(
3732 ) where {F, Y, C}
3833 DI. check_prep (f, prep, backend, x, ty, contexts... )
3934 dy = only (ty)
40- dy_righttype = dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
4135 new_y, (_, new_dx) = value_and_pullback!! (
42- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
43- prep. args_to_zero
36+ prep. cache, dy, f, x, map (DI. unwrap, contexts)... ; prep. args_to_zero
4437 )
4538 return new_y, (_copy_output (new_dx),)
4639end
@@ -55,11 +48,8 @@ function DI.value_and_pullback(
5548 ) where {F, Y, C}
5649 DI. check_prep (f, prep, backend, x, ty, contexts... )
5750 ys_and_tx = map (ty) do dy
58- dy_righttype =
59- dy isa tangent_type (Y) ? dy : _copy_to_output!! (prep. dy_righttype, dy)
6051 y, (_, new_dx) = value_and_pullback!! (
61- prep. cache, dy_righttype, f, x, map (DI. unwrap, contexts)... ;
62- prep. args_to_zero
52+ prep. cache, dy, f, x, map (DI. unwrap, contexts)... ; prep. args_to_zero
6353 )
6454 y, _copy_output (new_dx)
6555 end
@@ -121,9 +111,7 @@ function DI.prepare_gradient_nokwarg(
121111 ) where {F, C}
122112 _sig = DI. signature (f, backend, x, contexts... ; strict)
123113 config = get_config (backend)
124- cache = prepare_gradient_cache (
125- f, x, map (DI. unwrap, contexts)... ; config. debug_mode, config. silence_debug_messages
126- )
114+ cache = prepare_gradient_cache (f, x, map (DI. unwrap, contexts)... ; config)
127115 contexts_tup_false = map (_ -> false , contexts)
128116 args_to_zero = (
129117 false , # f
0 commit comments