-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathAbstractDifferentiationEnzymeExt.jl
More file actions
62 lines (57 loc) · 1.75 KB
/
AbstractDifferentiationEnzymeExt.jl
File metadata and controls
62 lines (57 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
module AbstractDifferentiationEnzymeExt
if isdefined(Base, :get_extension)
import AbstractDifferentiation as AD
using Enzyme: Enzyme
else
import ..AbstractDifferentiation as AD
using ..Enzyme: Enzyme
end
"""
    Mutating(f)

Wrap a callable `f` so it can be invoked in destination-passing form:
`m(out, args...)` evaluates `f(args...)`, broadcasts that result into the
preallocated buffer `out`, and returns `out` itself.
"""
struct Mutating{F}
    f::F
end

# Evaluate the wrapped function first, then copy its result into `out` in place.
function (m::Mutating)(out, args...)
    result = m.f(args...)
    out .= result
    return out
end
# Build the `Enzyme.Duplicated` annotation for the output buffer of `Mutating(f)`:
# a fresh primal buffer shaped like `y` plus a shadow seeded with the cotangent `Δ`.
#
# Both buffers are newly allocated on purpose:
# - `Mutating` writes into the primal, so reusing `y` would overwrite the value
#   already handed back to the caller.
# - Enzyme's reverse pass consumes the output shadow, so passing `Δ` directly
#   would clobber the caller's cotangent.
#
# The shadow is created with `similar(primal)` so its type always matches the
# primal, which `Duplicated` requires. A cotangent of another numeric or array
# type (e.g. an `Int` seed, or a view) is converted on copy.
#
# Accepted cotangents: a `Real` (or 1-tuple of one) for scalar outputs, and an
# `AbstractArray{<:Real}` (or 1-tuple of one) of the same size for other outputs.
# Throws `ArgumentError` for anything else, or `DimensionMismatch` on a size mismatch.
function _duplicated_output(y, Δ)
    if y isa Real
        δ = Δ isa Tuple{Real} ? only(Δ) : Δ
        δ isa Real || throw(ArgumentError("Unsupported cotangent type."))
        # Box the scalar so `Mutating` has a buffer to broadcast into.
        primal = [y]
        return Enzyme.Duplicated(primal, fill!(similar(primal), δ))
    else
        δ = Δ isa Tuple{AbstractArray{<:Real}} ? only(Δ) : Δ
        δ isa AbstractArray{<:Real} || throw(ArgumentError("Unsupported cotangent type."))
        size(δ) == size(y) || throw(
            DimensionMismatch(
                "cotangent has size $(size(δ)) but the output has size $(size(y))"
            ),
        )
        primal = similar(y)  # contents are overwritten by `Mutating` before being read
        shadow = similar(primal)
        shadow .= δ
        return Enzyme.Duplicated(primal, shadow)
    end
end

# Reverse-mode primitive: evaluate `y = f(xs...)` eagerly and return it together
# with a pullback closure. The pullback maps an output cotangent `Δ` to a tuple
# with one gradient buffer per input, each shaped like the corresponding `xs[i]`
# (built with `zero`).
AD.@primitive function value_and_pullback_function(b::AD.EnzymeReverseBackend, f, xs...)
    y = f(xs...)
    return y,
    Δ -> begin
        # Zero-initialised shadows: Enzyme accumulates the input adjoints into these.
        Δ_xs = zero.(xs)
        dup = _duplicated_output(y, Δ)
        # Differentiate the in-place form `Mutating(f)(out, xs...)`. The wrapper's
        # return is marked `Const`; the seed enters through the output shadow in `dup`.
        Enzyme.autodiff(
            Enzyme.Reverse,
            Mutating(f),
            Enzyme.Const,
            dup,
            Enzyme.Duplicated.(xs, Δ_xs)...,
        )
        return Δ_xs
    end
end
# Pushforwards requested from the Enzyme reverse backend are delegated to a fresh
# `AD.EnzymeForwardBackend()`, which owns the forward-mode primitive in this module.
AD.pushforward_function(::AD.EnzymeReverseBackend, f, xs...) =
    AD.pushforward_function(AD.EnzymeForwardBackend(), f, xs...)
# Forward-mode primitive: return a closure mapping tangents `ds` (one per input
# in `xs`) to the tuple produced by a single `Enzyme.autodiff(Enzyme.Forward, ...)`
# call with every input annotated as `Duplicated(x, dx)`.
#
# For a single-input `f`, a bare tangent (not wrapped in a tuple) is also accepted.
# This matches how AbstractDifferentiation's `value_and_pushforward_function`
# normalises its argument.
# Throws `DimensionMismatch` when the number of tangents differs from `length(xs)`.
AD.@primitive function pushforward_function(b::AD.EnzymeForwardBackend, f, xs...)
    return ds -> begin
        tangents = ds isa Tuple ? ds : (ds,)
        # Tuples broadcast against each other, so without this check a 1-tuple
        # would silently be reused as the tangent of every input.
        length(tangents) == length(xs) || throw(
            DimensionMismatch(
                "expected $(length(xs)) tangent(s), one per input, got $(length(tangents))"
            ),
        )
        return Tuple(
            Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated.(xs, tangents)...)
        )
    end
end
# Pullbacks requested from the Enzyme forward backend are delegated to a fresh
# `AD.EnzymeReverseBackend()`, which owns the reverse-mode primitive in this module.
AD.value_and_pullback_function(::AD.EnzymeForwardBackend, f, xs...) =
    AD.value_and_pullback_function(AD.EnzymeReverseBackend(), f, xs...)
end # module