diff --git a/src/compiler.jl b/src/compiler.jl index 4e2a1422db..dc10fd7b91 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1111,12 +1111,15 @@ end EnumAttribute("willreturn"), EnumAttribute("nosync"), EnumAttribute("nofree"), + StringAttribute("enzyme_preserve_primal", "*"), ] else LLVM.Attribute[EnumAttribute("memory", NoEffects.data), StringAttribute("enzyme_shouldrecompute"), EnumAttribute("willreturn"), EnumAttribute("nosync"), - EnumAttribute("nofree")] + EnumAttribute("nofree"), + StringAttribute("enzyme_preserve_primal", "*"), + ] end handleCustom(state, custom, k_name, llvmfn, name, attrs) return @@ -6795,19 +6798,31 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri mstr = if job.config.params.ABI <: InlineABI "" else + fixup_callconv!(mod, JIT.get_tm()) + for f in functions(mod) + for i in 1:length(parameters(f)) + for a in collect(parameter_attributes(f, i)) + if kind(a) == "enzyme_sret" + API.EnzymeDumpValueRef(f) + end + @assert kind(a) != "enzyme_sret" + @assert kind(a) != "enzyme_sret_v" + end + end + end string(mod) end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI if DumpPrePostOpt[] API.EnzymeDumpModuleRef(mod.ref) end - post_optimize!(mod, JIT.get_tm()) + post_optimize!(mod, JIT.get_tm(); callconv=false) if DumpPostOpt[] API.EnzymeDumpModuleRef(mod.ref) end else propagate_returned!(mod) - Compiler.JIT.prepare!(mod) + Compiler.JIT.prepare!(mod) end mstr else diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index d45761fc4f..d4927a31a2 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -392,11 +392,10 @@ end const DumpPreCallConv = Ref(false) const DumpPostCallConv = Ref(false) -function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true) +function fixup_callconv!(mod::LLVM.Module, tm::LLVM.TargetMachine) addr13NoAlias(mod) removeDeadArgs!(mod, tm, #=post_gc_fixup=#false) - memcpy_sret_split!(mod) # if we did the move_sret_tofrom_roots, we will have loaded out of the sret, then stored into the rooted. @@ -451,6 +450,25 @@ function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool ), ) end + return +end + +function post_optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine, machine::Bool = true; callconv::Bool = true) + if callconv + fixup_callconv!(mod, tm) + end + + for f in functions(mod) + if isempty(blocks(f)) + continue + end + if has_fn_attr(f, StringAttribute("enzyme_preserve_primal")) + delete!(LLVM.function_attributes(f), StringAttribute("enzyme_preserve_primal")) + end + end + + removeDeadArgs!(mod, tm, #=post_gc_fixup=#true) + @dispose pb = NewPMPassBuilder() begin registerEnzymeAndPassPipeline!(pb) register!(pb, ReinsertGCMarkerPass()) diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 167fc250b1..e82e905520 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2637,7 +2637,16 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup ) for u in LLVM.uses(fn) u = LLVM.user(u) - @assert isa(u, LLVM.CallInst) + if !isa(u, LLVM.CallInst) + # TODO investigate if the inttoptr store that comes from reference caller poses an issue. + continue + msg = sprint() do io + println(io, "Unknown user of fn: ", string(u)) + println(io, "fn: ", string(fn)) + println(io, "mod: ", string(LLVM.parent(fn))) + end + throw(AssertionError(msg)) + end B = IRBuilder() nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) position!(B, nextInst) @@ -2674,7 +2683,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine, post_gc_fixup for u in LLVM.uses(fn) u = LLVM.user(u) if isa(u, LLVM.ConstantExpr) - u = LLVM.user(only(LLVM.uses(u))) + for u in LLVM.uses(u) + u = LLVM.user(u) + if !isa(u, LLVM.CallInst) + continue + end + @assert isa(u, LLVM.CallInst) + B = IRBuilder() + nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u)) + position!(B, nextInst) + inp = operands(u)[idx] + cl = call!(B, funcT, sfunc, LLVM.Value[inp]) + if isa(value_type(inp), LLVM.PointerType) + LLVM.API.LLVMAddCallSiteAttribute( + cl, + LLVM.API.LLVMAttributeIndex(1), + EnumAttribute("nocapture"), + ) + end + end + continue end if !isa(u, LLVM.CallInst) continue diff --git a/src/utils.jl b/src/utils.jl index d2c5544c2b..fe869d7623 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -568,11 +568,11 @@ function sret_ty(fn::LLVM.Function, idx::Int)::LLVM.LLVMType end if ekind == "enzyme_sret" - ety = parse(UInt, LLVM.value(attr)) - ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) - ety = LLVM.LLVMType(ety) + ety = parse(UInt, LLVM.value(attr)) + ety = Base.reinterpret(LLVM.API.LLVMTypeRef, ety) + ety = LLVM.LLVMType(ety) if !LLVM.is_opaque(vt) - @assert ety == eltype(vt) + @assert ety == eltype(vt) "Mismatched sret type $(string(fn))\nidx=$idx\nety ($(string(ety))) != eltype(vt) (vt = $(string(vt)))" end return ety diff --git a/test/advanced.jl b/test/advanced.jl index bf60925ad3..228edb5331 100644 --- a/test/advanced.jl +++ b/test/advanced.jl @@ -546,6 +546,16 @@ end sqr(x) = x * x power(x, n) = x^n + + function objective1(x) + objvar = power(x, 2) + return objvar + end + + x0 = 2.1 + res = Enzyme.jacobian(Forward, Const(Enzyme.gradient), Const(Reverse), Const(objective1), x0) + @test res[3][1] ≈ 2.0 + function objective(x) (x1, x2, x3, x4) = x objvar = -4 - -(((((((((((((sqr(x1) + sqr(x2)) + sqr(x3 + x4)) + x3) + sqr(sin(x3))) + sqr(x1) * sqr(x2)) + x4) + sqr(sin(x3))) + sqr(-1 + x4)) + sqr(sqr(x2))) + sqr(sqr(x3) + sqr(x1 + x4))) + sqr(((-4 + sqr(sin(x4))) + sqr(x2) * sqr(x3)) + x1)) + power(sin(x4), 4))) diff --git a/test/ext/staticarrays.jl b/test/ext/staticarrays.jl index c2c55a9aa4..8d8a501935 100644 --- a/test/ext/staticarrays.jl +++ b/test/ext/staticarrays.jl @@ -89,4 +89,21 @@ end @test res ≈ [1.0, 0.0] res = Enzyme.gradient(Enzyme.Forward, unstable_fun, inp)[1] @test res ≈ [1.0, 0.0] -end \ No newline at end of file +end + +function inner_forhess(x) + return tanh.(x) +end + +function for_hess(x) + return sum(inner_forhess(x)) +end + +grad_forhess(x) = autodiff(Reverse, for_hess, Active, Active(x))[1][1] +hess(x) = jacobian(Forward, grad_forhess, x)[1] + +@testset "StaticArrays hessian" begin + x = @SVector zeros(10) + res = [-2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 -2.0] + @test jacobian(Forward, grad_forhess, x)[1] ≈ res +end