Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1111,12 +1111,15 @@ end
EnumAttribute("willreturn"),
EnumAttribute("nosync"),
EnumAttribute("nofree"),
StringAttribute("enzyme_preserve_primal", "*"),
]
else
LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"),
EnumAttribute("willreturn"),
EnumAttribute("nosync"),
EnumAttribute("nofree")]
EnumAttribute("nofree"),
StringAttribute("enzyme_preserve_primal", "*"),
]
end
handleCustom(state, custom, k_name, llvmfn, name, attrs)
return
Expand Down Expand Up @@ -6795,19 +6798,31 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
mstr = if job.config.params.ABI <: InlineABI
""
else
fixup_callconv!(mod, JIT.get_tm())
for f in functions(mod)
for i in 1:length(parameters(f))
for a in collect(parameter_attributes(f, i))
if kind(a) == "enzyme_sret"
API.EnzymeDumpValueRef(f)
end
@assert kind(a) != "enzyme_sret"
@assert kind(a) != "enzyme_sret_v"
end
end
end
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
if DumpPrePostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
post_optimize!(mod, JIT.get_tm())
post_optimize!(mod, JIT.get_tm(); callconv=false)
if DumpPostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
else
propagate_returned!(mod)
Compiler.JIT.prepare!(mod)
Compiler.JIT.prepare!(mod)
end
mstr
else
Expand Down
22 changes: 20 additions & 2 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,10 @@ end
const DumpPreCallConv = Ref(false)
const DumpPostCallConv = Ref(false)

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true)
function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine)
addr13NoAlias(mod)

removeDeadArgs!(mod, tm, #=post_gc_fixup=#false)


memcpy_sret_split!(mod)
# if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted.
Expand Down Expand Up @@ -451,6 +450,25 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool
),
)
end
return
end

function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true)
if callconv
fixup_callconv!(mod, tm)
end

for f in functions(mod)
if isempty(blocks(f))
continue
end
if has_fn_attr(f, StringAttribute("enzyme_preserve_primal"))
delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal"))
end
end

removeDeadArgs!(mod, tm, #=post_gc_fixup=#true)

@dispose pb = NewPMPassBuilder() begin
registerEnzymeAndPassPipeline!(pb)
register!(pb, ReinsertGCMarkerPass())
Expand Down
32 changes: 30 additions & 2 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2637,7 +2637,16 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
)
for u in LLVM.uses(fn)
u = LLVM.user(u)
@assert isa(u, LLVM.CallInst)
if !isa(u, LLVM.CallInst)
# TODO investigate if the inttoptr store that comes from reference caller poses an issue.
continue
msg = sprint() do io
println(io, "Unknown user of fn: ", string(u))
println(io, "fn: ", string(fn))
println(io, "mod: ", string(LLVM.parent(fn)))
end
throw(AssertionError(msg))
end
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
Expand Down Expand Up @@ -2674,7 +2683,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup
for u in LLVM.uses(fn)
u = LLVM.user(u)
if isa(u, LLVM.ConstantExpr)
u = LLVM.user(only(LLVM.uses(u)))
for u in LLVM.uses(u)
u = LLVM.user(u)
if !isa(u, LLVM.CallInst)
continue
end
@assert isa(u, LLVM.CallInst)
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
inp = operands(u)[idx]
cl = call!(B, funcT, sfunc, LLVM.Value[inp])
if isa(value_type(inp), LLVM.PointerType)
LLVM.API.LLVMAddCallSiteAttribute(
cl,
LLVM.API.LLVMAttributeIndex(1),
EnumAttribute("nocapture"),
)
end
end
continue
end
if !isa(u, LLVM.CallInst)
continue
Expand Down
8 changes: 4 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,11 +568,11 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType
end

if ekind == "enzyme_sret"
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
ety = parse(UInt, LLVM.value(attr))
ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety)
ety = LLVM.LLVMType(ety)
if !LLVM.is_opaque(vt)
@assert ety == eltype(vt)
@assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))"
end

return ety
Expand Down
10 changes: 10 additions & 0 deletions test/advanced.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,16 @@ end
# Square of `x`, computed as a single multiplication of `x` with itself.
function sqr(x)
    return x * x
end
# Raise `x` to the power `n` via the `^` operator.
function power(x, n)
    return x^n
end


# Scalar objective: `x` raised to the second power through the `power` helper.
objective1(x) = power(x, 2)

# Second-order check: take a forward-mode jacobian through a reverse-mode
# `Enzyme.gradient` call on `objective1`. The mode, the gradient function and
# the objective are wrapped in `Const`; only `x0` is passed unwrapped.
x0 = 2.1
res = Enzyme.jacobian(Forward, Const(Enzyme.gradient), Const(Reverse), Const(objective1), x0)
# `objective1` computes power(x, 2), whose second derivative is 2 for every x.
# NOTE(review): res[3] is presumably the derivative slot for the third
# argument (x0) — confirm against Enzyme.jacobian's return convention.
@test res[3][1] ≈ 2.0

function objective(x)
(x1, x2, x3, x4) = x
objvar = -4 - -(((((((((((((sqr(x1) + sqr(x2)) + sqr(x3 + x4)) + x3) + sqr(sin(x3))) + sqr(x1) * sqr(x2)) + x4) + sqr(sin(x3))) + sqr(-1 + x4)) + sqr(sqr(x2))) + sqr(sqr(x3) + sqr(x1 + x4))) + sqr(((-4 + sqr(sin(x4))) + sqr(x2) * sqr(x3)) + x1)) + power(sin(x4), 4)))
Expand Down
19 changes: 18 additions & 1 deletion test/ext/staticarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,21 @@ end
@test res ≈ [1.0, 0.0]
res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1]
@test res ≈ [1.0, 0.0]
end
end

# Apply `tanh` to `x` by broadcasting.
inner_forhess(x) = tanh.(x)

# Scalar objective for the hessian test: the sum of `inner_forhess(x)`.
for_hess(x) = sum(inner_forhess(x))

# Reverse-mode `autodiff` of `for_hess` with both the return and `x` marked
# `Active`; yields the `[1][1]` entry of the result.
# NOTE(review): presumably that entry is the gradient with respect to `x` —
# confirm against Enzyme's `autodiff` return convention.
function grad_forhess(x)
    return autodiff(Reverse, for_hess, Active, Active(x))[1][1]
end
# Forward-mode `jacobian` of `grad_forhess` at `x`; returns the first entry of
# the result.
function hess(x)
    return jacobian(Forward, grad_forhess, x)[1]
end

@testset "StaticArrays hessian" begin
    # Evaluation point: a length-10 static vector of zeros.
    x = @SVector zeros(10)
    # Expected matrix: 10x10 with -2.0 on the diagonal and 0.0 elsewhere.
    # NOTE(review): analytically d²/dx² tanh(x) = -2·tanh(x)·sech²(x), which
    # is 0 at x = 0, so a -2.0 diagonal looks surprising for this input —
    # confirm the expected values.
    res = [i == j ? -2.0 : 0.0 for i in 1:10, j in 1:10]
    @test jacobian(Forward, grad_forhess, x)[1] ≈ res
end
Loading