-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathAbstractDifferentiationEnzymeExt.jl
More file actions
61 lines (57 loc) · 2.05 KB
/
AbstractDifferentiationEnzymeExt.jl
File metadata and controls
61 lines (57 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
module AbstractDifferentiationEnzymeExt
if isdefined(Base, :get_extension)
import AbstractDifferentiation as AD
using Enzyme: Enzyme
else
import ..AbstractDifferentiation as AD
using ..Enzyme: Enzyme
end
# Jacobian of `f` at `x` computed with Enzyme's forward mode.
#
# `f(x)` is evaluated once up front to pick the code path:
#   * real scalar output: reuse `AD.gradient` for this backend and broadcast
#     `adjoint` over its result;
#   * anything else: call `Enzyme.jacobian` and wrap the result in a 1-tuple.
AD.@primitive function jacobian(b::AD.EnzymeForwardBackend, f, x)
    y = f(x)
    y isa Real && return adjoint.(AD.gradient(b, f, x))

    # Decide on the 1-input/1-output special case before differentiating.
    is_one_by_one = length(x) == 1 && length(y) == 1
    jac = Enzyme.jacobian(Enzyme.Forward, f, x)
    # Enzyme.jacobian returns a vector of length 1 in the one-by-one case, so it
    # is converted to a `Matrix` via its adjoint there.
    return is_one_by_one ? (Matrix(adjoint(jac)),) : (jac,)
end
# Real scalar input: the Jacobian is delegated unchanged to `AD.derivative`.
AD.jacobian(b::AD.EnzymeForwardBackend, f, x::Real) = AD.derivative(b, f, x)
# Gradient of `f` at the array `x` with Enzyme's forward mode, returned as a
# 1-tuple holding a vector.
function AD.gradient(::AD.EnzymeForwardBackend, f, x::AbstractArray)
    # Enzyme.gradient with Forward returns a tuple of the same length as the input;
    # its entries are gathered into a vector here.
    partials = Enzyme.gradient(Enzyme.Forward, f, x)
    return ([partials...],)
end
# Real scalar input: the gradient is delegated unchanged to `AD.derivative`.
AD.gradient(b::AD.EnzymeForwardBackend, f, x::Real) = AD.derivative(b, f, x)
# Derivative of `f` at the scalar `x` with Enzyme's forward mode.
#
# The scalar is wrapped in a one-element vector and `f` is adapted to read that
# single entry, so `Enzyme.gradient` can be used directly.
function AD.derivative(::AD.EnzymeForwardBackend, f, x::Number)
    wrapped_f = v -> f(v[1])
    wrapped_x = [x]
    # Enzyme.gradient with Forward returns a tuple of the same length as the input,
    # i.e. a 1-tuple for the one-element vector passed here.
    return Enzyme.gradient(Enzyme.Forward, wrapped_f, wrapped_x)
end
# Jacobian of `f` at `x` computed with Enzyme's reverse mode.
#
# `f(x)` is evaluated once up front to pick the code path:
#   * real scalar output: the adjoint of `Enzyme.gradient`, in a 1-tuple;
#   * anything else: `Enzyme.jacobian`, given the output length as a `Val`.
AD.@primitive function jacobian(::AD.EnzymeReverseBackend, f, x)
    y = f(x)
    if y isa Real
        grad = Enzyme.gradient(Enzyme.Reverse, f, x)
        return (adjoint(grad),)
    elseif length(x) == 1 && length(y) == 1
        # Enzyme.jacobian returns an adjoint vector of length 1 in this case,
        # so it is converted to a `Matrix`.
        return (Matrix(Enzyme.jacobian(Enzyme.Reverse, f, x, Val(1))),)
    end
    n_outputs = Val(length(y))
    return (Enzyme.jacobian(Enzyme.Reverse, f, x, n_outputs),)
end
# Gradient of `f` at the array `x` with Enzyme's reverse mode, returned as a
# 1-tuple holding an array allocated with `similar(x)`.
function AD.gradient(::AD.EnzymeReverseBackend, f, x::AbstractArray)
    # `similar` leaves the buffer uninitialised; it is handed to the in-place
    # `Enzyme.gradient!` and returned afterwards.
    # NOTE(review): relies on `gradient!` overwriting (not accumulating into)
    # the buffer — confirm against the Enzyme version in use.
    grad = similar(x)
    Enzyme.gradient!(Enzyme.Reverse, grad, f, x)
    return (grad,)
end
# Derivative of `f` at the scalar `x` with Enzyme's reverse mode.
#
# The scalar is wrapped in a one-element vector so the array `AD.gradient`
# method can be reused; `f` is adapted to read that single entry. The one
# gradient entry is returned in a 1-tuple.
function AD.derivative(b::AD.EnzymeReverseBackend, f, x::Number)
    # Forward the backend this method was dispatched on rather than
    # constructing a fresh `AD.EnzymeReverseBackend()`.
    grad = AD.gradient(b, v -> f(v[1]), [x])[1]
    return (grad[1],)
end
end # module